Coherent set comparison on Bickley jet

First, we define two functions that simulate the Bickley jet forward and backward in time.

import numpy as np
from deeptime.data import BickleyJet, bickley_jet

# Seeded random state used for drawing particle positions and for the noise
# added before back-transformation, so that the generated data is reproducible.
state = np.random.RandomState(seed=123)  # global random state

def draw_initial_positions(n):
    """Sample ``n`` particle positions uniformly in [0, 20) x [-3, 3).

    Draws from the module-level ``state`` random generator (x-coordinates
    first, then y-coordinates) and returns an array of shape (n, 2).
    """
    xs = state.uniform(0, 20, (n,))
    ys = state.uniform(-3, 3, (n, ))
    return np.stack([xs, ys]).T

def forward_transform(X, h=1e-3):
    """Integrate particle positions ``X`` forward in time from t=0 to t=40.

    Parameters
    ----------
    X : initial particle positions (presumably an (n, 2) array as produced by
        ``draw_initial_positions`` -- TODO confirm against BickleyJet.trajectory).
    h : positive integrator step size.

    Returns
    -------
    The trajectory produced by ``BickleyJet.trajectory`` with 401 stored frames.
    """
    # Integrator steps between two stored frames, chosen so that n_steps * h is
    # 0.1 and the 400 frame intervals reach t=40. ``round`` instead of ``int``
    # avoids losing a step when 1 / h / 10 lands just below an integer in
    # floating point.
    n_steps = round(1. / h / 10)
    simulator = BickleyJet(h=h, n_steps=n_steps)
    return simulator.trajectory(t0=0, x0=X, length=401)

def back_transform(Y, h=1e-3):
    """Integrate particle positions ``Y`` backward in time from t=40 to t=0.

    Parameters
    ----------
    Y : particle positions at t=40 (presumably an (n, 2) array -- TODO confirm).
    h : positive step size; the simulator is run with step ``-h``.

    Returns
    -------
    The trajectory produced by ``BickleyJet.trajectory`` with 401 stored frames;
    the last frame corresponds to t=0.
    """
    # Same frame spacing as ``forward_transform``; ``round`` guards against
    # 1 / h / 10 evaluating to just below an integer.
    n_steps = round(1. / h / 10)
    simulator = BickleyJet(h=-h, n_steps=n_steps)
    return simulator.trajectory(t0=40, x0=Y, length=401)

Now we simulate the Bickley jet starting from randomly drawn particle configurations: a training set and a larger test set.

# Draw training (3000) and test (15000) particles and integrate them forward to
# t=40. Both draws come from the seeded global ``state`` in this order, so
# reordering these statements would change the generated data.
Xinit_train = draw_initial_positions(3000)
traj_train = forward_transform(Xinit_train)
print("Simulated training data with {} particles for {} "
      "timesteps in {} dimensions.".format(*traj_train.shape))

Xinit_test = draw_initial_positions(15000)
traj_test = forward_transform(Xinit_test)
print("Simulated test data with {} particles for {} "
      "timesteps in {} dimensions.".format(*traj_test.shape))
Simulated training data with 3000 particles for 401 timesteps in 2 dimensions.
Simulated test data with 15000 particles for 401 timesteps in 2 dimensions.
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Scatter plot of the training particles at the first frame. Points are
# coloured by their initial x-coordinate divided by its maximum; ``handle`` is
# updated frame by frame in the animation callback.
f, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.set_xlim([0, 20])
ax.set_ylim([-4, 4])
c = traj_train[:, 0, 0] / traj_train[:, 0, 0].max()
handle = ax.scatter(*traj_train[:, 0].T, c=c)

def step(t):
    """Animation callback: move the scatter points to the positions of frame ``t``.

    Returns the list of modified artists, as FuncAnimation expects when
    blitting.
    """
    frame_positions = traj_train[:, t]
    handle.set_offsets(frame_positions)
    return [handle]

# ``frames`` limits the animation to the stored time steps; without it
# FuncAnimation keeps calling ``step`` with ever-increasing frame indices.
ani = animation.FuncAnimation(f, step, frames=traj_train.shape[1], interval=80,
                              blit=True, repeat=False)
plt.close()  # prevent figure from showing in output

Since this is a time-inhomogeneous system and we are interested in the coherent sets between the initial time \(s=0\) and the final time \(t=40\), we set up datasets that pair each particle's position at \(s=0\) with its position at \(t=40\).

from deeptime.util.data import TimeLaggedDataset

# Pair each particle's first-frame position (s=0) with its last-frame position (t=40).
ds_train = TimeLaggedDataset(data=traj_train[:, 0], data_lagged=traj_train[:, -1])
ds_test = TimeLaggedDataset(data=traj_test[:, 0], data_lagged=traj_test[:, -1])

Kernel CCA

We can now fit a kernel CCA model to the particle data mapping initial time positions to final time positions.

from deeptime.decomposition import KernelCCA
from deeptime.kernels import GaussianKernel

# Gaussian kernel bandwidth and ``epsilon`` (presumably a regularization
# parameter) for kernel CCA. NOTE(review): these look like tuned values; how
# they were obtained is not shown here.
bw, eps = 0.57863921, 0.00563756
# Fit on the (initial, final) training pairs, keeping 9 singular functions.
kcca = KernelCCA(GaussianKernel(bw), n_eigs=9, epsilon=eps).fit_fetch(ds_train)

Next, we evaluate the estimated singular functions on a grid over the domain and plot them; they already reveal some of the coherent structure.

def plot_singular_functions(transform, nrows=2, ncols=3):
    """Plot the first ``nrows * ncols`` output components of ``transform`` on the domain.

    ``transform`` is evaluated on a 300 x 180 grid covering [0, 20] x [-3, 3].
    Each output column is scaled to maximum absolute value 1 and drawn as a
    filled contour plot in its own subplot.

    Parameters
    ----------
    transform : callable mapping an (N, 2) array of positions to an (N, k) array.
    nrows, ncols : subplot grid shape. The defaults keep the original 2 x 3
        layout, so only the first six components are shown; pass e.g.
        ``nrows=3`` to see all nine components of a 9-dimensional model.
    """
    grid = np.meshgrid(np.linspace(0, 20, 300), np.linspace(-3, 3, 180))
    xy = np.dstack(grid).reshape(-1, 2)
    z = transform(xy)

    fig = plt.figure(figsize=(16, 4))
    gs = fig.add_gridspec(ncols=ncols, nrows=nrows)

    # Fill the grid row by row, stopping early if there are fewer components.
    for ix in range(min(nrows * ncols, z.shape[1])):
        row, col = divmod(ix, ncols)
        ax = fig.add_subplot(gs[row, col])
        values = z[:, ix].reshape(grid[0].shape)
        scale = np.max(np.abs(values))
        # An identically-zero component would otherwise turn the panel into NaNs.
        if scale > 0:
            values = values / scale
        ax.contourf(grid[0], grid[1], values, levels=15, cmap='bwr', vmin=-1, vmax=1)
        ax.set_title(f"Function {ix+1}")

# Only the real part of the kernel CCA transform is plotted
# (NOTE(review): presumably because the transform may return complex values).
plot_singular_functions(lambda x: kcca.transform(x).real)

We cluster the test data in the projected space to find crisp coherent set assignments.

from deeptime.clustering import KMeans

def cluster(X, n_iter=500, n_clusters=9):
    """Fit k-means ``n_iter`` times on ``X`` and return the lowest-inertia model.

    On ties the earliest fit wins. Raises ValueError if ``n_iter`` is 0.
    """
    fitted_models = (KMeans(n_clusters=n_clusters).fit_fetch(X) for _ in range(n_iter))
    return min(fitted_models, key=lambda model: model.inertia)

# Cluster the test particles in the (real part of the) kernel CCA feature space.
cluster_kcca = cluster(kcca.transform(ds_test.data).real)

Plotting the particles colored according to cluster center membership reveals the coherent set memberships.

# Map positions to crisp set labels: project with kernel CCA, then assign to a k-means cluster.
memberships_transform_kcca = lambda x: cluster_kcca.transform(kcca.transform(x).real)
memberships = memberships_transform_kcca(ds_test.data)

f, ax = plt.subplots(1, 1)
ax.scatter(*ds_test.data.T, c=memberships)

Next, we perturb the final positions with small Gaussian noise and integrate them backward to the initial time; this allows us to estimate a coherence score.

# Add Gaussian noise (standard deviation 0.1, drawn from the global ``state``) to
# the final test positions and integrate them back; the last frame of
# ``traj_test_backward`` then holds the back-transformed positions at t=0.
traj_test_backward = back_transform(traj_test[:, -1] + state.normal(scale=1e-1, size=traj_test[:, -1].shape))

To this end, we compare each particle's cluster assignment at its initial position with the assignment of its back-transformed position and record the particles whose assignment changed. From this we compute the expected probability that a particle stays within its own set.

def coherence_score(membership_transform):
    """Estimate how coherent the sets defined by ``membership_transform`` are.

    Every test particle is labelled twice: at its initial position
    (``Xinit_test``) and at its noisy, back-transformed position
    (``traj_test_backward[:, -1]``). Each (initial label, back-transformed
    label) pair is one observed transition. The score is the expected
    probability of staying in one's own set,

        score = sum_i P_ii * n_i / N,

    where n_i is the number of particles initially labelled i and N is the
    number of particles. P_ii = (#particles labelled i both times) / n_i is the
    row-normalized, i.e. non-reversible maximum-likelihood, transition
    probability.

    This is computed directly from the labels instead of through a count
    matrix sized by the number of distinct reference labels. That sizing broke
    whenever labels were not exactly 0..k-1, for example when a cluster is
    never assigned in the reference.

    Returns
    -------
    score : float
    mismatches : ndarray of int
        Indices of particles whose label changed.
    """
    reference = np.asarray(membership_transform(Xinit_test))
    backtransformed = np.asarray(membership_transform(traj_test_backward[:, -1]))

    stayed = reference == backtransformed
    # particle indices which did not get mapped to the same cluster center
    mismatches = np.flatnonzero(~stayed)

    labels, counts = np.unique(reference, return_counts=True)
    n_total = float(len(reference))
    score = 0.
    for label, n_label in zip(labels, counts):
        # Self-transition probability of this set, weighted by its share of particles.
        p_stay = np.count_nonzero(stayed & (reference == label)) / n_label
        score += p_stay * (n_label / n_total)

    return score, mismatches

def plot_mismatches(mismatches, membership_transform, n_points=5000):
    """Plot test particles at the initial (left) and final (right) time, highlighting mismatches.

    Only the first ``n_points`` test particles are drawn, or fewer if the test
    set is smaller. Particles listed in ``mismatches`` are drawn in red on top;
    all others are coloured by their reference set label.

    Parameters
    ----------
    mismatches : integer indices into the test set (e.g. from ``coherence_score``).
    membership_transform : callable mapping positions to set labels.
    n_points : maximum number of particles to plot (defaults to the original 5000).
    """
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    n = min(n_points, len(Xinit_test))
    X = Xinit_test[:n]
    Y = ds_test.data_lagged[:n]
    # Keep only mismatches that refer to a plotted particle.
    mismatches = np.asarray(mismatches)
    mismatches = mismatches[mismatches < n]

    ref = membership_transform(X)
    ix_good = np.setdiff1d(np.arange(n), mismatches, assume_unique=True)

    # Same drawing on both panels: initial positions left, final positions right.
    for ax, points in ((ax1, X), (ax2, Y)):
        ax.scatter(*points[ix_good].T, c=ref[ix_good])
        ax.scatter(*points[mismatches].T, c='red', s=10, zorder=500, linewidths=.5)
        ax.set_ylim([-4, 4])
# Score the kernel CCA sets and show which test particles changed set.
score_kcca, mismatches_kcca = coherence_score(memberships_transform_kcca)
print(f"Coherence score: {score_kcca:.2f}")
plot_mismatches(mismatches_kcca, memberships_transform_kcca)
Coherence score: 0.86


For VAMPNets we first transform the data into three-dimensional space to account for quasi-periodicity.

import torch
import torch.nn as nn
from deeptime.decomposition.deep import VAMPNet

# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
    # Let cuDNN benchmark and pick the fastest kernels for the fixed input sizes.
    torch.backends.cudnn.benchmark = True
else:
    device = torch.device("cpu")

Now we define the VAMPNet lobe.

# Feed-forward lobe: 3D-embedded positions in, 9 features out, with dropout
# after every hidden layer. Placed on ``device`` because the training tensors
# are created there.
lobe = nn.Sequential(
    nn.Linear(3, 256), nn.ELU(), nn.Dropout(),
    nn.Linear(256, 512), nn.ELU(), nn.Dropout(),
    nn.Linear(512, 128), nn.ELU(), nn.Dropout(),
    nn.Linear(128, 128), nn.ELU(), nn.Dropout(),
    nn.Linear(128, 9)
).to(device=device)

Training is performed as follows:

First, we define modified datasets. Since the data fits into GPU memory, we move it to the device up front and also transform it into three-dimensional space to deal with quasi-periodicity.

# Dataset that embeds X and Y into 3-dimensional space via BickleyJet.to_3d and
# stores both as float32 tensors on ``device``. No noise is added here.
class BickleyJet3DTorchDS:
    """Map-style dataset of (x, y) pairs of 3D-embedded Bickley jet positions.

    Parameters
    ----------
    ds : time-lagged dataset whose full slice ``ds[:]`` is the pair ``(X, Y)``
        of instantaneous and lagged positions.
    """

    def __init__(self, ds):
        X, Y = ds[:]
        # The 3D embedding accounts for the quasi-periodicity of the domain.
        self.X = torch.tensor(BickleyJet.to_3d(X), dtype=torch.float32, device=device)
        self.Y = torch.tensor(BickleyJet.to_3d(Y), dtype=torch.float32, device=device)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, ix):
        # ``ix`` may be a single index or a list of indices (as passed by the
        # batch samplers), since tensor indexing supports both.
        return self.X[ix], self.Y[ix]

# float32 copies of the train/test pairs, embedded in 3D and placed on ``device``.
ds_train_3d = BickleyJet3DTorchDS(ds_train.astype(np.float32))
ds_test_3d = BickleyJet3DTorchDS(ds_test.astype(np.float32))

Now we create the data loaders. Using batch samplers speeds up training, because the dataset is indexed once per batch instead of once per sample.

from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler

# A BatchSampler used as ``sampler`` yields whole lists of indices. With the
# DataLoader's default batch_size=1, each list is fetched in one dataset
# indexing call and wrapped in a one-element list, which ``collate_fn``
# unwraps. Training batches are shuffled and a short last batch is dropped;
# validation visits every sample in order.
collate_fn = lambda x: x[0]
train_sampler = BatchSampler(RandomSampler(ds_train_3d), batch_size=2048, drop_last=True)
loader_train = DataLoader(ds_train_3d, sampler=train_sampler, collate_fn=collate_fn)
val_sampler = BatchSampler(SequentialSampler(ds_test_3d), batch_size=2048, drop_last=False)
loader_val = DataLoader(ds_test_3d, sampler=val_sampler, collate_fn=collate_fn)

While VAMPNets possess a fit() method, we will implement the training loop ourselves to make use of checkpointing and learning rate scheduling.

from pathlib import Path
from deeptime.decomposition.deep import vamp_score
from deeptime.util.torch import CheckpointManager
from torch.optim.lr_scheduler import ReduceLROnPlateau

checkpoints_dir = Path('.') / 'checkpoints'
checkpoints = CheckpointManager(output_dir=checkpoints_dir)
opt = torch.optim.Adam(lobe.parameters(), 3e-4)
# Lower the learning rate once the (maximized) validation score stops improving.
scheduler = ReduceLROnPlateau(opt, mode='max', patience=200)

scores = [[], []]  # per-epoch [training scores, validation scores]

n_epochs = 8000
for epoch in tqdm(range(n_epochs)):
    # Training pass (dropout active): maximize the VAMP score by minimizing its negative.
    lobe.train()
    train_scores = []
    for X, Y in loader_train:
        opt.zero_grad()
        loss = -vamp_score(lobe(X), lobe(Y), mode='regularize')
        loss.backward()
        opt.step()
        train_scores.append(-loss.item())
    train_score = np.mean(train_scores)
    scores[0].append(train_score)

    # Validation pass (dropout disabled, no gradients needed).
    lobe.eval()
    with torch.no_grad():
        val_scores = []
        for Xval, Yval in loader_val:
            val_score = vamp_score(lobe(Xval), lobe(Yval), mode='regularize')
            val_scores.append(val_score)
        val_score = torch.mean(torch.stack(val_scores))
    scores[1].append(val_score.item())
    scheduler.step(val_score)
    checkpoints.step(epoch, val_score, models=dict(lobe=lobe))

We plot training and validation scores. Due to dropout, the validation score is expected to be higher than the training score.

plt.semilogx(scores[0], label='training')
plt.semilogx(scores[1], label='validation')
plt.legend()  # without it the labels above are never shown

We load the model with the best validation score.

from deeptime.decomposition.deep import VAMPNetModel

# Restore the lobe weights from the best-scoring checkpoint. The checkpoint holds
# a dict keyed by model name, matching ``models=dict(lobe=lobe)`` in the
# training loop. Then wrap the lobe in a VAMPNetModel.
lobe.load_state_dict(torch.load(checkpoints_dir / 'best.ckpt')['lobe'])
vampnet = VAMPNetModel(lobe, device=device)

We can now use the trained VAMPNet model to create a VAMP model instance.

from deeptime.decomposition import VAMP

# Linear VAMP model on top of the trained VAMPNet features (the network sees the
# 3D embedding of each position), fitted on the training pairs.
vamp_vampnet = VAMP(
    observable_transform=lambda x: vampnet.transform(BickleyJet.to_3d(x)), dim=9, epsilon=1e-12
).fit_fetch(ds_train)

Obtain a clustering:

# Cluster the test particles in the VAMPNet/VAMP feature space.
cluster_vampnet = cluster(vamp_vampnet.transform(ds_test.data))
memberships_transform_vampnet = lambda x: cluster_vampnet.transform(vamp_vampnet.transform(x))

memberships = memberships_transform_vampnet(ds_test.data)

f, ax = plt.subplots(1, 1)
ax.scatter(*ds_test.data.T, c=memberships)

and approximate the coherence score.

# Score the VAMPNet sets and show which test particles changed set.
score_vampnet, mismatches_vampnet = coherence_score(memberships_transform_vampnet)
print(f"Coherence score: {score_vampnet:.2f}")
plot_mismatches(mismatches_vampnet, memberships_transform_vampnet)
Coherence score: 0.76


Here we fit a VAMP model with a set of randomized feature functions.

class ChiRnd:
    """Random feature map: a two-layer network with fixed random weights.

    Positions are embedded into 3D with ``BickleyJet.to_3d``, passed through a
    random affine layer with activation ``exp(-x**2)``, and then through a
    second random affine layer.

    Parameters
    ----------
    n : number of hidden random features.
    fan_in : input dimension; must match the output dimension of
        ``BickleyJet.to_3d``, which is 3 in this notebook.
    scale : standard deviation of the Gaussian weight matrices.
    bias_var : biases are drawn uniformly from [-bias_var, bias_var]
        (despite the name, a half-width rather than a variance).
    out_dim : number of output features.
    random_state : optional ``np.random.RandomState`` (or ``Generator``) for
        reproducible weights. ``None`` keeps the previous behaviour of drawing
        from NumPy's global random generator.
    """

    def __init__(self, n=100, fan_in=3, scale=1, bias_var=5, out_dim=50, random_state=None):
        rng = np.random if random_state is None else random_state
        self.n_basis = n
        self.out_dim = out_dim
        self.W = rng.normal(scale=scale, size=(fan_in, self.n_basis))
        self.b = rng.uniform(-bias_var, bias_var, size=(self.n_basis,))

        self.W2 = rng.normal(scale=scale, size=(self.n_basis, out_dim))
        self.b2 = rng.uniform(-bias_var, bias_var, size=(out_dim,))

    def nonlinearity(self, x):
        """Gaussian activation ``exp(-x**2)``."""
        return np.exp(-x*x)

    def __call__(self, x):
        """Map positions ``x`` to ``out_dim`` random features."""
        return self.nonlinearity(BickleyJet.to_3d(x) @ self.W + self.b) @ self.W2 + self.b2

chi = ChiRnd()
vamp = VAMP(lagtime=1, dim=9, observable_transform=chi).fit_fetch(ds_train)
# Cluster the test particles in the VAMP feature space.
cluster_vamp = cluster(vamp.transform(ds_test.data))

memberships_transform_vamp = lambda x: cluster_vamp.transform(vamp.transform(x))
memberships = memberships_transform_vamp(ds_test.data)

f, ax = plt.subplots(1, 1)
ax.scatter(*ds_test.data.T, c=memberships)
score_vamp, mismatches_vamp = coherence_score(memberships_transform_vamp)
print(f"Coherence score: {score_vamp:.2f}")
plot_mismatches(mismatches_vamp, memberships_transform_vamp)
Coherence score: 0.74


from deeptime.decomposition import KVAD

# NOTE(review): this KVAD(...) call is never closed and no fit is visible, yet
# ``kvad.transform`` is used below. The remaining arguments (presumably an
# ``observable_transform``) and a ``.fit_fetch(...)`` call appear to be missing;
# restore them before running.
kvad = KVAD(kernel=GaussianKernel(1.),
            lagtime=1, epsilon=1e-5, dim=9,

# Cluster the test particles in the KVAD feature space.
cluster_kvad = cluster(kvad.transform(ds_test.data))

memberships_transform_kvad = lambda x: cluster_kvad.transform(kvad.transform(x))
memberships = memberships_transform_kvad(ds_test.data)

f, ax = plt.subplots(1, 1)
ax.scatter(*ds_test.data.T, c=memberships)
score_kvad, mismatches_kvad = coherence_score(memberships_transform_kvad)
print(f"Coherence score: {score_kvad:.2f}")
plot_mismatches(mismatches_kvad, memberships_transform_kvad)
Coherence score: 0.73


Similar to VAMPNets, we optimize the KVAD score with a neural network to find a KVAD-score-optimal parametrization.

from deeptime.decomposition.deep import kvad_score
from deeptime.kernels import TorchGaussianKernel
# Gaussian kernel applied to the lagged 3D positions inside the KVAD score.
kernel = TorchGaussianKernel(.5)

# Same architecture as the VAMPNet lobe: 3D-embedded positions in, 9 features
# out, dropout after every hidden layer. Placed on ``device`` like the data.
kvadnets_lobe = nn.Sequential(
    nn.Linear(3, 256), nn.ELU(), nn.Dropout(),
    nn.Linear(256, 512), nn.ELU(), nn.Dropout(),
    nn.Linear(512, 128), nn.ELU(), nn.Dropout(),
    nn.Linear(128, 128), nn.ELU(), nn.Dropout(),
    nn.Linear(128, 9)
).to(device=device)

from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler

# Same batch-sampler setup as for the VAMPNet, with smaller batches of 512.
collate_fn = lambda x: x[0]
train_sampler = BatchSampler(RandomSampler(ds_train_3d), batch_size=512, drop_last=True)
loader_train = DataLoader(ds_train_3d, sampler=train_sampler, collate_fn=collate_fn)
val_sampler = BatchSampler(SequentialSampler(ds_test_3d), batch_size=512, drop_last=False)
loader_val = DataLoader(ds_test_3d, sampler=val_sampler, collate_fn=collate_fn)

opt = torch.optim.Adam(kvadnets_lobe.parameters(), 3e-4)

scores = [[], []]  # per-epoch [training scores, validation scores]

n_epochs = 15000
for epoch in tqdm(range(n_epochs)):
    # Training pass (dropout active): maximize the KVAD score by minimizing its negative.
    kvadnets_lobe.train()
    epoch_train_scores = []
    for batch_0, batch_t, in loader_train:
        opt.zero_grad()
        chi_0 = kvadnets_lobe(batch_0)
        loss = -kvad_score(chi_0, batch_t, kernel=kernel)
        loss.backward()
        opt.step()
        epoch_train_scores.append(-loss.item())
    scores[0].append(np.mean(epoch_train_scores))

    # Validation pass (dropout disabled, no gradients needed).
    kvadnets_lobe.eval()
    epoch_val_scores = []
    with torch.no_grad():
        for batch_0, batch_t, in loader_val:
            chi_0 = kvadnets_lobe(batch_0)
            val_score = kvad_score(chi_0, batch_t, kernel=kernel)
            epoch_val_scores.append(val_score.item())
    scores[1].append(np.mean(epoch_val_scores))


# definition of the estimated transformation
def chi_kvadnet(X):
    """Evaluate the trained KVADNet lobe on positions ``X``.

    ``X`` is embedded into 3D with ``BickleyJet.to_3d``, converted to a float32
    tensor on ``device`` and passed through ``kvadnets_lobe`` without gradient
    tracking. The lobe contains dropout layers, so it is run in eval mode
    (dropout off) to make the features deterministic. Its previous train/eval
    mode is restored afterwards.

    Returns the lobe output as a NumPy array.
    """
    X = BickleyJet.to_3d(X)
    was_training = kvadnets_lobe.training
    kvadnets_lobe.eval()
    try:
        with torch.no_grad():
            chi_X = kvadnets_lobe(torch.from_numpy(X.astype(np.float32)).to(device=device))
    finally:
        kvadnets_lobe.train(was_training)
    return chi_X.cpu().numpy()
plt.semilogx(scores[0], label='training')
plt.semilogx(scores[1], label='validation')
plt.legend()  # without it the labels above are never shown
# NOTE(review): this KVAD(...) call is never closed and no fit is visible, yet
# ``kvadnet.transform`` is used below. The remaining arguments (e.g. ``dim``)
# and a ``.fit_fetch(...)`` call appear to be missing; restore them before running.
kvadnet = KVAD(
    kernel=kernel, observable_transform=chi_kvadnet,

# Cluster the test particles in the KVADNet feature space.
cluster_kvadnet = cluster(kvadnet.transform(ds_test.data))

memberships_transform_kvadnet = lambda x: cluster_kvadnet.transform(kvadnet.transform(x))
memberships = memberships_transform_kvadnet(ds_test.data)

f, ax = plt.subplots(1, 1)
ax.scatter(*ds_test.data.T, c=memberships)
score_kvadnet, mismatches_kvadnet = coherence_score(memberships_transform_kvadnet)
print(f"Coherence score: {score_kvadnet:.2f}")
plot_mismatches(mismatches_kvadnet, memberships_transform_kvadnet)
Coherence score: 0.88

Comparison of scores

We obtain the following:

def vamp2_score(transform):
    """Return the VAMP-2 score of a 9-dimensional VAMP model on features ``transform``.

    The model is fitted on the test pairs ``ds_test``.
    """
    model = VAMP(observable_transform=transform, dim=9).fit_fetch(ds_test)
    return model.score(2)

def kvad_score(transform):
    """Return the KVAD score of a 9-dimensional KVAD model on features ``transform``.

    The model uses ``GaussianKernel(0.5)`` and is fitted on the test pairs
    ``ds_test``.

    NOTE: this definition shadows ``kvad_score`` imported from
    ``deeptime.decomposition.deep``, which the KVADNet training loop uses.
    Re-running that loop after this point would call this function instead.
    """
    model = KVAD(GaussianKernel(0.5), observable_transform=transform, dim=9).fit_fetch(ds_test)
    return model.score
import tabulate

# One row per method: VAMP-2 and KVAD scores of models fitted on the test pairs,
# plus the coherence score computed earlier.
print(tabulate.tabulate([
        ["KVADNets", vamp2_score(kvadnet.transform), kvad_score(kvadnet.transform), score_kvadnet],
        ["Kernel CCA", vamp2_score(lambda x: kcca.transform(x).real), kvad_score(lambda x: kcca.transform(x).real), score_kcca],
        ["VAMPNets", vamp2_score(vamp_vampnet.transform), kvad_score(vamp_vampnet.transform), score_vampnet],
        ["VAMP", vamp2_score(vamp.transform), kvad_score(vamp.transform), score_vamp],
        ["KVAD", vamp2_score(kvad.transform), kvad_score(kvad.transform), score_kvad],
    ],
    headers=["Method", "VAMP-2 score", "KVAD score", "Coherence score"]
))
Method        VAMP-2 score    KVAD score    Coherence score
----------  --------------  ------------  -----------------
KVADNets           5.88749     0.0796252             0.8802
Kernel CCA         5.8305      0.0762447             0.8598
VAMPNets           7.07734     0.0749358             0.7622
VAMP               4.75647     0.0677209             0.738
KVAD               4.6162      0.0667439             0.7298