Dimension reduction of double-wedge dataset

Here we use the double-wedge (or sqrt) model to compare some of the dimension reduction methods in a time-homogeneous setting.

[1]:
from matplotlib import pyplot as plt
import numpy as np
from tqdm.notebook import tqdm

from deeptime.data import sqrt_model

First we generate the training data: the binary classification according to wedge (dtraj) and the corresponding two-dimensional state vectors (traj). The goal is to reconstruct dtraj from traj.

[2]:
dtraj, traj = sqrt_model(n_samples=1000, seed=777)

We can define the back-transform to make the drawn trajectory frames linearly separable according to state:

[3]:
def back_transform(x):
    """Undo the square-root warp of the double-wedge data.

    Returns a copy of ``x`` in which ``sqrt(|x[:, 0]|)`` has been subtracted
    from the second column; the input array itself is left untouched.
    """
    unwarped = np.copy(x)
    first, second = unwarped[:, 0], unwarped[:, 1]
    unwarped[:, 1] = second - np.sqrt(np.abs(first))
    return unwarped

The dataset is set up as follows: The discrete trajectory (upper left panel) has corresponding two-dimensional and linearly separable datapoints (upper right panel), which are subsequently transformed through a square-root-like function to produce the double wedge (lower left panel). The lower right panel shows a kernel density estimate of the transformed observations.

[4]:
# Regular 100x100 grid spanning the bounding box of the observed frames.
X, Y = np.meshgrid(
    np.linspace(np.min(traj[:, 0]), np.max(traj[:, 0]), 100),
    np.linspace(np.min(traj[:, 1]), np.max(traj[:, 1]), 100),
)
# Flatten the grid into a (10000, 2) list of evaluation points.
kde_input = np.dstack((X, Y)).reshape(-1, 2)

import scipy.stats as stats
# Gaussian kernel density estimate of the frames, evaluated on the grid and
# reshaped back to the grid layout for contour plotting.
kernel = stats.gaussian_kde(traj.T, bw_method=.1)
Z = kernel(kde_input.T).reshape(X.shape)

f, axes = plt.subplots(2, 2, figsize=(16, 12))
# Upper left: first 500 steps of the discrete trajectory.
ax1 = axes[0, 0]
ax1.plot(dtraj[:500])
ax1.set_title('Discrete trajectory')
ax1.set_xlabel('time (a.u.)')
ax1.set_ylabel('state')

# Upper right: back-transformed frames, colored by their state in dtraj.
ax2 = axes[0, 1]
ax2.set_title('Data prior transformation')
scatter = ax2.scatter(*back_transform(traj).T, c=dtraj, cmap='BrBG')
legend1 = ax2.legend(*scatter.legend_elements(), title="States", loc='lower left')
ax2.add_artist(legend1)

# Lower left: the frames as generated, colored by their state in dtraj.
ax3 = axes[1, 0]
ax3.set_title('Data')
scatter = ax3.scatter(*traj.T, c=dtraj, cmap='BrBG')
legend1 = ax3.legend(*scatter.legend_elements(), title="States", loc='lower left')
ax3.add_artist(legend1)

# Lower right: filled contours of the kernel density estimate.
ax4 = axes[1, 1]
cm = ax4.contourf(X, Y, Z)
plt.colorbar(cm, ax=ax4);
ax4.set_title('Heatmap of observations');
../../_images/notebooks_examples_sqrt-model-dimrx_7_0.png

Let us now compare the different methods. To this end we first define a plotting routine.

[5]:
def plot_method(transform, score, score_std):
    r""" Plotting method for the different models.

    Produces a two-row figure. The upper panel shows the projection as filled
    contours on a grid around the data, overlaid with the frames of ``traj``
    colored by their two-means cluster assignment in the projected space;
    frames whose assignment disagrees with ``dtraj`` are drawn with a red
    edge. The lower panel shows, for the first 200 frames, the rescaled
    projection, the estimated state and the ground truth over time.

    Reads the notebook-level variables ``traj`` and ``dtraj``.

    Parameters
    ----------
    transform : Callable
        Function pointer that projects data into a one-dimensional space.
    score
        VAMP-2 score
    score_std
        VAMP-2 score standard deviation

    Returns
    -------
    ax_contour
        Axes of the upper panel (contours and scatter).
    ax
        Axes of the lower panel (time series).
    """
    from deeptime.clustering import KMeans
    from matplotlib.colors import Normalize

    feat = transform(traj)
    # Two-means clustering of the projection gives the estimated state per frame.
    feat_cc = KMeans(2).fit_transform(feat)

    xmax = np.max(np.abs(traj[:, 0]))
    ymin = np.min(traj[:, 1])
    ymax = np.max(traj[:, 1])
    vmin = np.min(feat)
    vmax = np.max(feat)
    vmag = np.max([np.abs(vmin), np.abs(vmax)])

    scatter_size = 35

    # 500x500 evaluation grid, padded by 1 around the data in both directions.
    grid = np.meshgrid(np.linspace(-xmax-1, xmax+1, 500), np.linspace(ymin-1, ymax+1, 500))
    xy = np.dstack(grid).reshape(-1, 2)
    levels = np.linspace(0, vmag, num=5)

    # The grid is projected in three chunks rather than in a single call.
    z = np.concatenate([transform(xxyy) for xxyy in np.array_split(xy, 3)])
    z00 = transform(np.array([0, -1]).reshape(-1, 2))
    # Limit the grid values to the range observed on the actual data.
    z = np.clip(z, vmin, vmax)

    # Cluster labels are arbitrary: swap them when they disagree with the
    # ground truth on the first 500 frames.
    # NOTE(review): the swap thresholds (.3 here, .4 for the projection below)
    # are heuristics; .5 would be the symmetric choice - confirm intent.
    if np.mean(np.abs((feat_cc[:500].squeeze() - dtraj[:500]))) > .3:
        feat_cc = 1. - feat_cc
    # The sign of the projection is arbitrary as well; it is fixed through the
    # value of the projection at the reference point (0, -1).
    if z00 < 0:
        z *= -1.
    z *= -1

    cmap = 'BrBG'
    f = plt.figure(figsize=(6, 6), constrained_layout=True)
    gs = f.add_gridspec(2, 1, height_ratios=[2, 1])

    ax_contour = f.add_subplot(gs[0, 0])

    # Contour levels symmetric around zero, with a finer spacing between
    # -levels[1] and +levels[1].
    norm = Normalize(-vmag, vmag)
    levels_fine = np.linspace(0, levels[1], num=5)
    levels = np.concatenate([-levels[2:][::-1], -levels_fine[1:][::-1], levels_fine, levels[2:]])

    contours = ax_contour.contourf(grid[0], grid[1], z.reshape(grid[0].shape), levels=levels,
                                   cmap=cmap, norm=norm, extend='both')

    ticks = levels[1::3]
    cb = f.colorbar(contours, ax=ax_contour, ticks=ticks)
    cb.ax.set_yticklabels([f"{x:.2f}" for x in ticks])

    state_colors = (plt.cm.plasma(.3), plt.cm.plasma(.8))
    agrees = feat_cc == dtraj
    differs = feat_cc != dtraj

    def _scatter_state(state, correct):
        """Scatter the frames assigned to ``state`` that are (in)correctly classified."""
        ix = np.where((feat_cc == state) & (agrees if correct else differs))[0]
        if correct:
            # Correct frames are always drawn so that both legend entries exist.
            style = dict(linewidths=.1, edgecolors='black', label=f'Estimated state {state}')
        elif len(ix) > 0:
            # Misclassified frames get a red edge and no legend entry.
            style = dict(linewidths=.5, edgecolors='red')
        else:
            return
        # A fixed `color` is given, so no `cmap` is passed: matplotlib would
        # ignore it and emit a warning.
        ax_contour.scatter(*traj[ix].T, color=state_colors[state], s=scatter_size,
                           marker='o', zorder=500, alpha=.5, **style)

    for correct in (True, False):
        for state in (0, 1):
            _scatter_state(state, correct)

    ax_contour.legend()

    props = dict(boxstyle='square', facecolor='white', alpha=.8)
    ax_contour.text(0.5, 0.08, r"$s={:.2f}\pm {:.2f}$".format(score, score_std),
                    transform=ax_contour.transAxes, verticalalignment='top',
                    horizontalalignment='center', bbox=props, fontsize='x-small')

    ax = f.add_subplot(gs[1, 0])
    n_proj = 200
    # Percentage of all frames whose estimated state differs from the ground truth.
    mismatch = 100. * float(np.count_nonzero(feat_cc - dtraj)) / float(len(dtraj))
    rescaled_proj = .5*(feat[:n_proj] + 1)
    feat_cc = feat_cc[:n_proj]
    # Mirror the rescaled projection if it runs opposite to the ground truth.
    if np.mean(np.abs((rescaled_proj.squeeze() - dtraj[:n_proj]))) > .4:
        rescaled_proj = 1. - rescaled_proj
    xs = np.arange(len(rescaled_proj))
    ax.plot(xs, rescaled_proj, alpha=.3, color='C0')
    ax.plot(xs, feat_cc, color='C0', label=r'Estimated (${:.1f}\%$ mismatch)'.format(mismatch))
    ax.plot(xs, dtraj[:n_proj], color='C1', linestyle='dotted', label='Ground truth')
    ax.set_xlabel('time (a.u.)')
    ax.set_ylabel(r'$\chi (x_t)$')
    ax.set_ylim([-.5, 1.5])
    ax.set_yticks([0, 1])
    ax.set_title(r"${:.1f}\%$ acc.".format(100-mismatch))
    ax.legend(loc="lower center", ncol=2)

    return ax_contour, ax

TICA

We use the linear TICA method.

[6]:
from deeptime.decomposition import vamp_score_cv
from deeptime.decomposition import TICA

# Linear TICA with a one-dimensional projection at lag time 1.
estimator = TICA(lagtime=1, dim=1)
# The model fitted on the full trajectory is fetched here, before the
# estimator itself is handed to the cross-validation below.
tica = estimator.fit_fetch(traj)
# Cross-validated scores; their mean and standard deviation go into the plot.
scores = vamp_score_cv(estimator, traj, blocksize=100, random_state=42)

plot_method(tica.transform, scores.mean(), scores.std());
../../_images/notebooks_examples_sqrt-model-dimrx_11_0.png

EDMD

We use EDMD with a polynomial ansatz basis of degree up to \(2\).

[7]:
from deeptime.decomposition import EDMD, VAMP
from deeptime.basis import Monomials

# Monomial ansatz basis constructed with p=2, d=2.
basis = Monomials(p=2, d=2)
edmd = EDMD(basis=basis).fit_fetch(traj, lagtime=1)

# VAMP on top of the real part of the EDMD features. This estimator is only
# passed to vamp_score_cv for scoring; it is not used for the plot.
vamp_edmd = VAMP(lagtime=1, dim=1,
                 observable_transform=lambda x: edmd.transform(x).real)
scores = vamp_score_cv(vamp_edmd, traj, blocksize=100, random_state=42)

# The plot uses the real part of component 1 of the EDMD transform directly.
# NOTE(review): presumably component 0 is the trivial (constant) one - confirm.
# NOTE(review): unlike the other cells this call has no trailing ';', so the
# returned axes tuple is echoed as cell output.
plot_method(lambda x: edmd.transform(x)[..., 1].real, scores.mean(), scores.std())
[7]:
(<AxesSubplot:>,
 <AxesSubplot:title={'center':'$97.6\\%$ acc.'}, xlabel='time (a.u.)', ylabel='$\\chi (x_t)$'>)
../../_images/notebooks_examples_sqrt-model-dimrx_13_1.png

VAMP with backtransform

We use VAMP with knowledge about the ground truth to make data linearly separable.

[8]:
# VAMP estimator whose observables are the back-transformed coordinates.
vamp_back = VAMP(lagtime=1, dim=1,
                 observable_transform=back_transform).fit(traj)
# Fetch the model fitted on the full trajectory *before* cross-validation:
# vamp_score_cv refits the estimator it is given on data splits, after which
# vamp_back.transform would no longer project with the full-data model.
vamp_back_model = vamp_back.fetch_model()
scores = vamp_score_cv(vamp_back, traj, blocksize=100, random_state=42)
plot_method(vamp_back_model.transform, scores.mean(), scores.std());
../../_images/notebooks_examples_sqrt-model-dimrx_15_0.png

Kernel EDMD

We use kernel EDMD with a Gaussian kernel and parameter optimization in whitened space.

[9]:
from deeptime.covariance import Covariance
# Covariance model of the trajectory at lag time 1 (C0t is computed as well).
cov = Covariance(lagtime=1, compute_c0t=True).fit(traj).fetch_model()
# Whitened trajectory; kernel EDMD is fitted on it in the following cell.
whitened_traj = cov.whiten(traj)
[10]:
from deeptime.kernels import GaussianKernel
from deeptime.decomposition import KernelEDMD

# Gaussian kernel with parameter 1.4, applied in whitened space.
kernel = GaussianKernel(1.4)
kedmd = KernelEDMD(kernel=kernel, epsilon=1.6e-3, n_eigs=2).fit_fetch(whitened_traj, lagtime=1)


def kedmd_transf(x):
    """Whiten ``x`` and return the real part of its two kernel EDMD features."""
    return (kedmd.transform(cov.whiten(x)).real).reshape(-1, 2)


vamp_kedmd = VAMP(lagtime=1, dim=1, epsilon=1e-16, observable_transform=kedmd_transf).fit(traj)
# Fetch the model fitted on the full trajectory *before* cross-validation:
# vamp_score_cv refits the estimator it is given on data splits, after which
# vamp_kedmd.transform would no longer project with the full-data model.
vamp_kedmd_model = vamp_kedmd.fetch_model()
scores = vamp_score_cv(vamp_kedmd, traj, blocksize=100, random_state=42)

plot_method(vamp_kedmd_model.transform, scores.mean(), scores.std());
../../_images/notebooks_examples_sqrt-model-dimrx_18_0.png

Kernel CCA

We use kernel CCA with optimized kernel parameters.

[11]:
from deeptime.decomposition import KernelCCA

from scipy.optimize import Bounds, minimize

def objective(params):
    """Objective for the kernel parameter search.

    ``params`` unpacks into the Gaussian kernel bandwidth and the kernel CCA
    regularization. Kernel CCA is fitted on the lag-1 pairs of ``traj``, a
    VAMP model is fitted on the real part of the resulting features, and the
    negated ``score(2)`` of that VAMP model is returned, so that minimizing
    this function maximizes the score.
    """
    bandwidth, regularization = params
    kcca_estimator = KernelCCA(GaussianKernel(bandwidth), 1, epsilon=regularization)
    kcca_model = kcca_estimator.fit((traj[:-1], traj[1:])).fetch_model()

    def kcca_features(x):
        return kcca_model.transform(x).real

    vamp_estimator = VAMP(lagtime=1, dim=1, observable_transform=kcca_features)
    vamp_model = vamp_estimator.fit(traj).fetch_model()
    return -vamp_model.score(2)

# Box constraints: bandwidth and regularization are both kept within [1e-2, 1].
bounds = Bounds([1e-2, 1e-2], [1e-0, 1e-0])
# SLSQP minimization started from bandwidth 0.8 and regularization 0.3.
result = minimize(objective, x0=[.8, .3], bounds=bounds, method='SLSQP')
# NOTE(review): result.success is not checked before the parameters are used.
bw, eps = result.x
print(f"Estimated bandwidth {bw:.3e} and regularization parameter {eps:.3e}.")
Estimated bandwidth 8.551e-01 and regularization parameter 1.000e-02.
[12]:
# Kernel CCA with the bandwidth and regularization found by the optimizer.
kernel = GaussianKernel(bw)
kcca = KernelCCA(kernel, 1, epsilon=eps).fit_fetch(traj, lagtime=1)


def kcca_transform(x):
    """Return the real part of the kernel CCA features of ``x``."""
    return kcca.transform(x).real


vamp_kcca = VAMP(lagtime=1, dim=1, observable_transform=kcca_transform).fit(traj)
# Fetch the model fitted on the full trajectory *before* cross-validation:
# vamp_score_cv refits the estimator it is given on data splits, after which
# vamp_kcca.transform would no longer project with the full-data model.
vamp_kcca_model = vamp_kcca.fetch_model()
scores = vamp_score_cv(vamp_kcca, traj, blocksize=100, random_state=42)

plot_method(vamp_kcca_model.transform, scores.mean(), scores.std());
../../_images/notebooks_examples_sqrt-model-dimrx_21_0.png

VAMPNets

We use VAMPNets with a multilayer perceptron (MLP) architecture.

[13]:
import torch.nn as nn
from torch.utils.data import DataLoader

from deeptime.util.data import TrajectoryDataset
from deeptime.util.torch import MLP
from deeptime.decomposition.deep import VAMPNet

# Dataset built from the float32-cast trajectory (first argument 1,
# presumably the lag time - confirm against TrajectoryDataset).
dataset = TrajectoryDataset(1, traj.astype(np.float32))
# The lobe is an MLP going from the input dimension through hidden layers of
# 15, 10, 10 and 5 units to a single output, with ReLU nonlinearities.
estimator = VAMPNet(lobe=MLP(units=[traj.shape[1], 15, 10, 10, 5, 1], nonlinearity=nn.ReLU),
                    learning_rate=1e-3)
# NOTE(review): no random seed is set in this cell, so training is presumably
# not reproducible from run to run - confirm.
loader_train = DataLoader(dataset, batch_size=128, shuffle=True)
vampnet = estimator.fit_fetch(loader_train, n_epochs=170, progress=tqdm)
[14]:
# VAMP on top of the trained VAMPNet output, used as observable transform.
vamp_vampnet = VAMP(lagtime=1, dim=1, observable_transform=vampnet).fit(traj)
# Fetch the model fitted on the full trajectory *before* cross-validation:
# vamp_score_cv refits the estimator it is given on data splits, after which
# vamp_vampnet.transform would no longer project with the full-data model.
vamp_vampnet_model = vamp_vampnet.fetch_model()
scores = vamp_score_cv(vamp_vampnet, traj, blocksize=100, random_state=42)

plot_method(vamp_vampnet_model.transform, scores.mean(), scores.std());
../../_images/notebooks_examples_sqrt-model-dimrx_24_0.png

Markov state models

We use MSMs with a box discretization of various degrees of fineness.

[15]:
from deeptime.clustering import BoxDiscretization
from deeptime.markov import TransitionCountEstimator
from deeptime.markov.msm import MaximumLikelihoodMSM

def estimate(n_boxes):
    """Estimate an MSM on a box discretization of ``traj`` and plot its projection.

    The trajectory is discretized with ``BoxDiscretization``, a maximum
    likelihood MSM is estimated on the largest connected submodel of the
    sliding-window transition counts, and the first two right eigenvectors of
    the MSM serve as observables for a VAMP model. The result is drawn with
    ``plot_method`` and the cluster centers are marked with black crosses.

    Parameters
    ----------
    n_boxes
        Passed on to ``BoxDiscretization`` as ``n_boxes``.
    """
    clustering = BoxDiscretization(dim=2, n_boxes=n_boxes).fit_fetch(traj)
    count_estimator = TransitionCountEstimator(lagtime=1, count_mode='sliding',
                                               n_states=clustering.n_clusters)
    counts = count_estimator.fit_fetch(clustering.transform(traj)).submodel_largest()
    msm = MaximumLikelihoodMSM(lagtime=1).fit_fetch(counts)
    # The eigenvectors do not depend on the input of `transform`, so they are
    # computed once here instead of on every call.
    evr = msm.eigenvectors_right(2)

    def transform(xy):
        """Map frames to the MSM eigenvector entries of their discrete state."""
        dtr = clustering.transform(xy)
        dtr = msm.count_model.transform_discrete_trajectories_to_submodel(dtr)
        # Frames mapped to -1 (not part of the submodel) keep an all-zero row.
        invalids = dtr == -1
        output = np.zeros(shape=(len(dtr), 2))
        output[~invalids, :] = evr[dtr[~invalids]]
        return output

    vamp_msm = VAMP(lagtime=1, dim=1, observable_transform=transform).fit(traj)
    # Fetch the model fitted on the full trajectory *before* cross-validation:
    # vamp_score_cv refits the estimator it is given on data splits, after
    # which vamp_msm.transform would no longer project with the full-data model.
    vamp_msm_model = vamp_msm.fetch_model()
    scores = vamp_score_cv(vamp_msm, traj, blocksize=10, random_state=42)

    ax, _ = plot_method(vamp_msm_model.transform, scores.mean(), scores.std())
    ax.scatter(*clustering.cluster_centers.T, color='black', marker='+')
[16]:
estimate(5)
../../_images/notebooks_examples_sqrt-model-dimrx_27_0.png
[17]:
estimate(7)
../../_images/notebooks_examples_sqrt-model-dimrx_28_0.png
[18]:
estimate(15)
../../_images/notebooks_examples_sqrt-model-dimrx_29_0.png