Coherent set comparison on Bickley jet

First, we define helper functions that draw random initial positions and simulate the Bickley jet forward and backward in time.

[1]:
import numpy as np
from deeptime.data import BickleyJet, bickley_jet

state = np.random.RandomState(seed=123)  # global random state

def draw_initial_positions(n):
    """Sample ``n`` points uniformly from the rectangle [0, 20] x [-3, 3].

    The module-level random generator ``state`` is used; all x coordinates
    are drawn first, followed by all y coordinates.

    Returns an array of shape ``(n, 2)`` with one position per row.
    """
    xs = state.uniform(0, 20, (n,))
    ys = state.uniform(-3, 3, (n,))
    return np.array([xs, ys]).T

def forward_transform(X, h=1e-3):
    """Integrate the particles ``X`` forward in time, starting at ``t0=0``.

    A ``BickleyJet`` simulator with positive step size ``h`` is used and
    401 frames are requested; per the notebook this evolves ``X`` to t=40.
    """
    steps_per_frame = int(1. / h / 10)
    integrator = BickleyJet(h=h, n_steps=steps_per_frame)
    return integrator.trajectory(t0=0, x0=X, length=401)

def back_transform(Y, h=1e-3):
    """Integrate the particles ``Y`` backward in time, starting at ``t0=40``.

    A ``BickleyJet`` simulator with negative step size ``-h`` is used and
    401 frames are requested; per the notebook this evolves ``Y`` to t=0.
    """
    steps_per_frame = int(1. / h / 10)
    integrator = BickleyJet(h=-h, n_steps=steps_per_frame)
    return integrator.trajectory(t0=40, x0=Y, length=401)

Now we can simulate the Bickley jet using particle configurations.

[2]:
# Training set: 3000 random start positions, integrated forward in time.
Xinit_train = draw_initial_positions(3000)
traj_train = forward_transform(Xinit_train)
train_shape = traj_train.shape
print("Simulated training data with {} particles for {} "
      "timesteps in {} dimensions.".format(*train_shape))

# Test set: 15000 further start positions drawn from the same random state.
Xinit_test = draw_initial_positions(15000)
traj_test = forward_transform(Xinit_test)
test_shape = traj_test.shape
print("Simulated test data with {} particles for {} "
      "timesteps in {} dimensions.".format(*test_shape))
Simulated training data with 3000 particles for 401 timesteps in 2 dimensions.
Simulated test data with 15000 particles for 401 timesteps in 2 dimensions.
[3]:
from IPython.display import HTML
import matplotlib.pyplot as plt
import matplotlib.animation as animation

# Animate the training trajectory as a scatter plot, one frame per time step.
f, ax = plt.subplots(1, 1, figsize=(12, 4))
ax.set_xlim([0, 20])
ax.set_ylim([-4, 4])
# colour every particle by its initial x coordinate, normalised by the largest one
c = traj_train[:, 0, 0] / traj_train[:, 0, 0].max()
handle = ax.scatter(*traj_train[:, 0].T, c=c)
ax.set_aspect('equal')

def step(t):
    """Move the scatter points to the particle positions of frame ``t``."""
    handle.set_offsets(traj_train[:, t])
    # blit=True below: return the artists that were modified
    return [handle]

# one animation frame per stored trajectory frame
ani = animation.FuncAnimation(f, step, interval=80, blit=True, repeat=False,
                              frames=traj_train.shape[1])
plt.close()  # prevent figure from showing in output
HTML(ani.to_html5_video())
[3]:

Since this is a time-inhomogeneous system and we are interested in the coherent sets that can be found for initial time \(s=0\) and final time \(t=40\), we set up appropriate datasets.

[4]:
from deeptime.util.data import TimeLaggedDataset

# Pair each particle's first frame (data) with its last frame (data_lagged).
ds_train = TimeLaggedDataset(data=traj_train[:, 0], data_lagged=traj_train[:, -1])
ds_test = TimeLaggedDataset(data=traj_test[:, 0], data_lagged=traj_test[:, -1])

Kernel CCA

We can now fit a kernel CCA model to the particle data mapping initial time positions to final time positions.

[5]:
from deeptime.decomposition import KernelCCA
from deeptime.kernels import GaussianKernel

# bw parameterises the Gaussian kernel; eps is handed to KernelCCA as epsilon.
# NOTE(review): the origin of these two values is not shown here - presumably tuned offline.
bw, eps = 0.57863921, 0.00563756
# Fit on the training pairs (n_eigs=9) and keep the fitted model.
kcca = KernelCCA(GaussianKernel(bw), n_eigs=9, epsilon=eps).fit_fetch(ds_train)

Now we plot the evaluation of the estimated singular functions on the domain and already observe some of the coherent structure.

[6]:
def plot_singular_functions(transform):
    """Draw up to six components of ``transform`` as filled contour plots.

    ``transform`` is evaluated on a 300 x 180 grid covering
    [0, 20] x [-3, 3]; it receives an ``(n, 2)`` array of positions and must
    return an ``(n, k)`` array. Each plotted component is rescaled by its
    largest absolute value and shown in a 2 x 3 panel layout (row-major).
    """
    gx, gy = np.meshgrid(np.linspace(0, 20, 300), np.linspace(-3, 3, 180))
    points = np.column_stack((gx.ravel(), gy.ravel()))
    values = transform(points)

    fig = plt.figure(figsize=(16, 4))
    layout = fig.add_gridspec(ncols=3, nrows=2)

    # at most 2 * 3 panels, fewer if the transform has fewer output columns
    for ix in range(min(6, values.shape[1])):
        row, col = divmod(ix, 3)
        panel = fig.add_subplot(layout[row, col])
        component = values[:, ix].reshape(gx.shape)
        component /= np.max(np.abs(component))
        panel.contourf(gx, gy, component, levels=15, cmap='bwr', vmin=-1, vmax=1)
        panel.set_title(f"Function {ix+1}")
        panel.set_aspect('equal')

plot_singular_functions(lambda x: kcca.transform(x).real)
../../_images/notebooks_examples_coherence-bickley-jet_10_0.png

We cluster the test data in the projected space to find crisp coherent set assignments.

[7]:
from deeptime.clustering import KMeans

def cluster(X, n_iter=500, n_clusters=9):
    """Run k-means ``n_iter`` times on ``X`` and return the lowest-inertia fit.

    Every run fits a fresh ``KMeans(n_clusters=n_clusters)``; if several runs
    share the minimal inertia, the earliest one is returned.
    """
    fits = (KMeans(n_clusters=n_clusters).fit_fetch(X) for _ in range(n_iter))
    return min(fits, key=lambda model: model.inertia)

cluster_kcca = cluster(kcca.transform(ds_test.data).real)

Plotting the particles colored according to cluster center membership reveals the coherent set memberships.

[8]:
def memberships_transform_kcca(x):
    """Project ``x`` with kernel CCA (real part) and map it through the k-means model."""
    return cluster_kcca.transform(kcca.transform(x).real)

memberships = memberships_transform_kcca(ds_test.data)

# initial test positions, coloured by their cluster membership
f, ax = plt.subplots(1, 1)
ax.scatter(ds_test.data[:, 0], ds_test.data[:, 1], c=memberships)
ax.set_aspect('equal')
../../_images/notebooks_examples_coherence-bickley-jet_14_0.png

Let us now add some noise to the final frame and integrate it backward in time to the initial time in order to estimate the coherence score.

[9]:
# Perturb the final test frame with Gaussian noise (scale 0.1) and integrate it back in time.
final_frame = traj_test[:, -1]
noise = state.normal(scale=1e-1, size=final_frame.shape)
traj_test_backward = back_transform(final_frame + noise)

To this end, we compare the cluster assignments and find the ones which changed compared to the initial state above. With this we can compute the expected probability that a particle stays within its own state.

[10]:
def coherence_score(membership_transform):
    """Estimate the probability that a particle keeps its cluster assignment.

    The assignment of every initial test position (``Xinit_test``) is compared
    with the assignment of the last frame of ``traj_test_backward``, i.e. the
    noisy final frame after it has been integrated back in time.

    Parameters
    ----------
    membership_transform : callable
        Maps an array of positions to one integer cluster index per position.

    Returns
    -------
    score : float
        Sum over clusters of the estimated self-transition probability,
        weighted by the fraction of reference particles in that cluster.
    mismatches : ndarray of int
        Indices of the particles whose two assignments differ.
    """
    from deeptime.markov import TransitionCountEstimator
    from deeptime.markov.msm import MaximumLikelihoodMSM

    reference = membership_transform(Xinit_test)
    backtransformed = membership_transform(traj_test_backward[:, -1])

    # particle indices which did not get mapped to the same cluster center
    mismatches = np.where(reference != backtransformed)[0]

    # one two-frame discrete trajectory (reference -> back-transformed) per particle
    dtrajs = list(np.column_stack((reference, backtransformed)))

    labels, counts = np.unique(reference, return_counts=True)
    cm = TransitionCountEstimator(1, 'sliding', n_states=len(labels)).fit_fetch(dtrajs)
    msm = MaximumLikelihoodMSM(reversible=False, allow_disconnected=True).fit_fetch(cm)

    # weight each self-transition probability by the cluster's share of particles
    weights = counts / float(len(reference))
    score = sum(msm.transition_matrix[label, label] * weight
                for label, weight in zip(labels, weights))

    return score, mismatches

def plot_mismatches(mismatches, membership_transform, n_show=5000):
    """Show which particles changed their cluster assignment, at t=0 and t=40.

    Parameters
    ----------
    mismatches : ndarray of int
        Indices into the test set of particles whose assignment changed.
    membership_transform : callable
        Maps an array of positions to one integer cluster index per position.
    n_show : int, default 5000
        Only the first ``n_show`` test particles are drawn.

    Particles with an unchanged assignment are coloured by their cluster;
    mismatching particles are overlaid as small red markers.
    """
    f, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8))

    X = Xinit_test[:n_show]
    Y = ds_test.data_lagged[:n_show]
    # bound by the number of particles actually available, which may be < n_show
    n_shown = len(X)
    mismatches = mismatches[mismatches < n_shown]

    ref = membership_transform(X)
    ix_good = np.setdiff1d(np.arange(n_shown), mismatches, assume_unique=True)

    for ax, points, title in ((ax1, X, r'$t=0$'), (ax2, Y, r'$t=40$')):
        ax.scatter(*points[ix_good].T, c=ref[ix_good])
        ax.scatter(*points[mismatches].T, c='red', s=10, zorder=500, linewidths=.5)
        ax.set_aspect('equal')
        ax.set_ylim([-4, 4])
        ax.set_title(title)
[11]:
# Score the kernel CCA clustering and plot the particles whose assignment changed.
score_kcca, mismatches_kcca = coherence_score(memberships_transform_kcca)
print(f"Coherence score: {score_kcca:.2f}")
plot_mismatches(mismatches_kcca, memberships_transform_kcca)
Coherence score: 0.86
../../_images/notebooks_examples_coherence-bickley-jet_19_1.png

VAMPNets

For VAMPNets we first transform the data into three-dimensional space to account for quasi-periodicity.

[12]:
import torch
import torch.nn as nn
from deeptime.decomposition.deep import VAMPNet

# Prefer the GPU when one is available; cuDNN benchmarking is enabled only then.
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
if use_cuda:
    torch.backends.cudnn.benchmark = True
torch.set_num_threads(8)

Now we define the VAMPNet lobe.

[13]:
# Fully connected lobe: batch norm on the 3 inputs, four Linear/ELU/Dropout
# stages and a final linear layer with 9 outputs.
hidden_sizes = [(3, 256), (256, 512), (512, 128), (128, 128)]
lobe_layers = [nn.BatchNorm1d(3)]
for n_in, n_out in hidden_sizes:
    lobe_layers.extend([nn.Linear(n_in, n_out), nn.ELU(), nn.Dropout()])
lobe_layers.append(nn.Linear(128, 9))
lobe = nn.Sequential(*lobe_layers).to(device=device)

Training is performed as follows:

First, we define modified datasets. Since the data fits into GPU memory we already put it on the device and also transform it into three-dimensional space to deal with quasi-periodicity.

[14]:
# Dataset which lifts X and Y into 3-dimensional space and stores them as torch tensors
class BickleyJet3DTorchDS:
    """Pairs of instantaneous / time-lagged positions as float32 tensors.

    Both halves of ``ds`` are mapped through ``BickleyJet.to_3d`` once at
    construction time and placed on the module-level ``device``.
    """

    def __init__(self, ds):
        instantaneous, lagged = ds[:]
        self.X, self.Y = (
            torch.tensor(BickleyJet.to_3d(frame), dtype=torch.float32, device=device)
            for frame in (instantaneous, lagged)
        )

    def __len__(self):
        # number of stored (X, Y) pairs
        return len(self.X)

    def __getitem__(self, ix):
        # ``ix`` may be anything a tensor accepts as an index
        return self.X[ix], self.Y[ix]

# Cast both datasets to float32 before wrapping them as torch datasets.
ds_train_3d = BickleyJet3DTorchDS(ds_train.astype(np.float32))
ds_test_3d = BickleyJet3DTorchDS(ds_test.astype(np.float32))

Now we create loaders. We can speed up training by using batch samplers.

[15]:
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler

def collate_fn(batches):
    """Return the first (and presumably only) element handed over by the DataLoader."""
    return batches[0]

# The batch samplers are passed as ``sampler``, so every drawn "sample" is
# already a whole batch of indices; training batches are shuffled and
# incomplete ones dropped, validation batches are sequential and all kept.
train_sampler = BatchSampler(RandomSampler(ds_train_3d), batch_size=2048, drop_last=True)
val_sampler = BatchSampler(SequentialSampler(ds_test_3d), batch_size=2048, drop_last=False)
loader_train = DataLoader(ds_train_3d, sampler=train_sampler, collate_fn=collate_fn)
loader_val = DataLoader(ds_test_3d, sampler=val_sampler, collate_fn=collate_fn)

While VAMPNets possess a fit() method, we will implement the training loop ourselves to make use of checkpointing and learning rate scheduling.

[16]:
from pathlib import Path
from deeptime.decomposition.deep import vamp_score
from deeptime.util.torch import CheckpointManager
from torch.optim.lr_scheduler import ReduceLROnPlateau

# Checkpoints are written below ./checkpoints.
checkpoints_dir = Path('.') / 'checkpoints'
checkpoints = CheckpointManager(output_dir=checkpoints_dir)
opt = torch.optim.Adam(lobe.parameters(), 3e-4)
# the validation score is maximised, hence mode='max'
scheduler = ReduceLROnPlateau(opt, mode='max', patience=200)

# scores[0]: mean training score per epoch, scores[1]: mean validation score per epoch
scores = [[], []]

n_epochs = 8000
for epoch in tqdm(range(n_epochs)):
    # --- training pass (dropout / batch-norm in training mode) ---
    train_scores = []
    lobe.train()
    for X, Y in loader_train:
        opt.zero_grad()
        # the score is maximised, so its negative is the loss
        loss = -vamp_score(lobe(X), lobe(Y), mode='regularize')
        train_scores.append((-loss).item())
        loss.backward()
        opt.step()
    train_score = np.mean(train_scores)
    scores[0].append(train_score)

    # --- validation pass; no gradients are needed, so skip autograd bookkeeping ---
    lobe.eval()
    val_scores = []
    with torch.no_grad():
        for Xval, Yval in loader_val:
            val_scores.append(vamp_score(lobe(Xval), lobe(Yval), mode='regularize'))
        val_score = torch.mean(torch.stack(val_scores))
    scheduler.step(val_score)
    scores[1].append(val_score.item())
    checkpoints.step(epoch, val_score, models=dict(lobe=lobe))

We plot training and validation scores. Because dropout is only active during training, the validation score is expected to be higher than the training score.

[17]:
# Training vs. validation score per epoch (logarithmic epoch axis).
plt.semilogx(scores[0], label='training')
plt.semilogx(scores[1], label='validation')
plt.xlabel('epoch')
plt.ylabel('score')
plt.legend();
../../_images/notebooks_examples_coherence-bickley-jet_31_0.png

We load the model with the best validation score.

[18]:
from deeptime.decomposition.deep import VAMPNetModel

# Restore the 'lobe' entry of best.ckpt from the checkpoint directory
# (presumably written by the CheckpointManager during training).
lobe.load_state_dict(torch.load(checkpoints_dir / 'best.ckpt')['lobe'])
# Wrap the restored network in a model object that offers transform().
vampnet = VAMPNetModel(lobe, device=device)

We can now use the trained VAMPNet model to create a VAMP model instance.

[19]:
from deeptime.decomposition import VAMP

# VAMP on top of the network output: positions are lifted with BickleyJet.to_3d
# and passed through the trained VAMPNet before the estimator sees them.
vamp_vampnet = VAMP(
    observable_transform=lambda x: vampnet.transform(BickleyJet.to_3d(x)), dim=9, epsilon=1e-12
).fit_fetch(ds_train)

plot_singular_functions(vamp_vampnet.transform)
../../_images/notebooks_examples_coherence-bickley-jet_35_0.png

Obtain a clustering:

[20]:
cluster_vampnet = cluster(vamp_vampnet.transform(ds_test.data))

def memberships_transform_vampnet(x):
    """Project ``x`` with the VAMPNet-based VAMP model and map it through the k-means model."""
    return cluster_vampnet.transform(vamp_vampnet.transform(x))

memberships = memberships_transform_vampnet(ds_test.data)

# initial test positions, coloured by their cluster membership
f, ax = plt.subplots(1, 1)
ax.scatter(ds_test.data[:, 0], ds_test.data[:, 1], c=memberships)
ax.set_aspect('equal')
../../_images/notebooks_examples_coherence-bickley-jet_37_0.png

and approximate the coherence score.

[21]:
# Score the VAMPNet clustering and plot the particles whose assignment changed.
score_vampnet, mismatches_vampnet = coherence_score(memberships_transform_vampnet)
print(f"Coherence score: {score_vampnet:.2f}")
plot_mismatches(mismatches_vampnet, memberships_transform_vampnet)
Coherence score: 0.76
../../_images/notebooks_examples_coherence-bickley-jet_39_1.png

VAMP

Here we fit a VAMP model with a set of randomized feature functions.

[22]:
class ChiRnd:
    """Randomly initialised two-layer feature map.

    Calling an instance evaluates
    ``exp(-(to_3d(x) @ W + b)**2) @ W2 + b2`` where ``to_3d`` is
    ``BickleyJet.to_3d``.

    Parameters
    ----------
    n : int
        Number of hidden basis functions (columns of ``W``).
    fan_in : int
        Input dimension after the 3D lifting (rows of ``W``).
    scale : float
        Standard deviation of the normally distributed weights.
    bias_var : float
        Biases are drawn uniformly from ``[-bias_var, bias_var]``.
    out_dim : int
        Number of output features.
    rng : optional
        Source of randomness offering ``normal`` and ``uniform`` (e.g. a
        ``numpy.random.RandomState``). Defaults to the global ``np.random``
        module; pass a seeded generator for reproducible features.
    """

    def __init__(self, n=100, fan_in=3, scale=1, bias_var=5, out_dim=50, rng=None):
        if rng is None:
            rng = np.random
        self.n_basis = n
        self.out_dim = out_dim
        # first layer: fan_in -> n_basis
        self.W = rng.normal(scale=scale, size=(fan_in, self.n_basis))
        self.b = rng.uniform(-bias_var, bias_var, size=(self.n_basis,))

        # second layer: n_basis -> out_dim
        self.W2 = rng.normal(scale=scale, size=(self.n_basis, out_dim))
        self.b2 = rng.uniform(-bias_var, bias_var, size=(out_dim,))

    def nonlinearity(self, x):
        """Element-wise Gaussian bump ``exp(-x**2)``."""
        return np.exp(-x*x)

    def __call__(self, x):
        hidden = self.nonlinearity(BickleyJet.to_3d(x) @ self.W + self.b)
        return hidden @ self.W2 + self.b2

chi = ChiRnd()
[23]:
# VAMP with the random feature map `chi` as observable transform, fitted on the training pairs.
vamp = VAMP(lagtime=1, dim=9, observable_transform=chi).fit_fetch(ds_train)
plot_singular_functions(vamp.transform)
../../_images/notebooks_examples_coherence-bickley-jet_42_0.png
[24]:
cluster_vamp = cluster(vamp.transform(ds_test.data))

def memberships_transform_vamp(x):
    """Project ``x`` with the VAMP model and map it through the k-means model."""
    return cluster_vamp.transform(vamp.transform(x))

memberships = memberships_transform_vamp(ds_test.data)

# initial test positions, coloured by their cluster membership
f, ax = plt.subplots(1, 1)
ax.scatter(ds_test.data[:, 0], ds_test.data[:, 1], c=memberships)
ax.set_aspect('equal')
../../_images/notebooks_examples_coherence-bickley-jet_43_0.png
[25]:
# Score the VAMP clustering and plot the particles whose assignment changed.
score_vamp, mismatches_vamp = coherence_score(memberships_transform_vamp)
print(f"Coherence score: {score_vamp:.2f}")
plot_mismatches(mismatches_vamp, memberships_transform_vamp)
Coherence score: 0.74
../../_images/notebooks_examples_coherence-bickley-jet_44_1.png

KVAD

[26]:
from deeptime.decomposition import KVAD

# KVAD with a Gaussian kernel and random features, fitted on the training pairs.
# Note that a fresh ChiRnd() is drawn here, i.e. not the `chi` instance used for VAMP.
kvad = KVAD(kernel=GaussianKernel(1.),
            lagtime=1, epsilon=1e-5, dim=9,
            observable_transform=ChiRnd()).fit_fetch(ds_train)

plot_singular_functions(kvad.transform)
../../_images/notebooks_examples_coherence-bickley-jet_46_0.png
[27]:
cluster_kvad = cluster(kvad.transform(ds_test.data))

def memberships_transform_kvad(x):
    """Project ``x`` with the KVAD model and map it through the k-means model."""
    return cluster_kvad.transform(kvad.transform(x))

memberships = memberships_transform_kvad(ds_test.data)

# initial test positions, coloured by their cluster membership
f, ax = plt.subplots(1, 1)
ax.scatter(ds_test.data[:, 0], ds_test.data[:, 1], c=memberships)
ax.set_aspect('equal')
../../_images/notebooks_examples_coherence-bickley-jet_47_0.png
[28]:
# Score the KVAD clustering and plot the particles whose assignment changed.
score_kvad, mismatches_kvad = coherence_score(memberships_transform_kvad)
print(f"Coherence score: {score_kvad:.2f}")
plot_mismatches(mismatches_kvad, memberships_transform_kvad)
Coherence score: 0.73
../../_images/notebooks_examples_coherence-bickley-jet_48_1.png

KVADNets

Similar to VAMPNets we optimize the KVAD score with a neural network to find a KVAD-score optimal parametrization.

[29]:
from deeptime.decomposition.deep import kvad_score
from deeptime.kernels import TorchGaussianKernel
[30]:
# Torch Gaussian kernel passed to kvad_score during training
kernel = TorchGaussianKernel(.5)

# Same layer layout as the VAMPNet lobe: batch norm on the 3 inputs, four
# Linear/ELU/Dropout stages and a final linear layer with 9 outputs.
kvadnets_hidden_sizes = [(3, 256), (256, 512), (512, 128), (128, 128)]
kvadnets_layers = [nn.BatchNorm1d(3)]
for n_in, n_out in kvadnets_hidden_sizes:
    kvadnets_layers.extend([nn.Linear(n_in, n_out), nn.ELU(), nn.Dropout()])
kvadnets_layers.append(nn.Linear(128, 9))
kvadnets_lobe = nn.Sequential(*kvadnets_layers).to(device=device)

from tqdm.notebook import tqdm
from torch.utils.data import DataLoader, BatchSampler, RandomSampler, SequentialSampler

def collate_fn(batches):
    """Return the first (and presumably only) element handed over by the DataLoader."""
    return batches[0]

# Batch samplers are passed as ``sampler``; batches of 512, shuffled with
# incomplete batches dropped for training, sequential and complete for validation.
train_sampler = BatchSampler(RandomSampler(ds_train_3d), batch_size=512, drop_last=True)
val_sampler = BatchSampler(SequentialSampler(ds_test_3d), batch_size=512, drop_last=False)
loader_train = DataLoader(ds_train_3d, sampler=train_sampler, collate_fn=collate_fn)
loader_val = DataLoader(ds_test_3d, sampler=val_sampler, collate_fn=collate_fn)

opt = torch.optim.Adam(kvadnets_lobe.parameters(), 3e-4)

# scores[0]: mean training score per epoch, scores[1]: mean validation score per epoch
scores = [[], []]

n_epochs = 15000
for epoch in tqdm(range(n_epochs)):
    # --- training pass (dropout / batch-norm in training mode) ---
    epoch_train_scores = []
    kvadnets_lobe.train()
    for batch_0, batch_t in loader_train:
        opt.zero_grad()
        # only the instantaneous batch goes through the network; the lagged
        # batch is handed to kvad_score as is
        chi_0 = kvadnets_lobe(batch_0)
        # the score is maximised, so its negative is the loss
        loss = -kvad_score(chi_0, batch_t, kernel=kernel)
        loss.backward()
        opt.step()
        epoch_train_scores.append((-loss).item())
    scores[0].append(np.mean(epoch_train_scores))

    # --- validation pass; no gradients are needed, so skip autograd bookkeeping ---
    epoch_val_scores = []
    kvadnets_lobe.eval()
    with torch.no_grad():
        for batch_0, batch_t in loader_val:
            chi_0 = kvadnets_lobe(batch_0)
            val_score = kvad_score(chi_0, batch_t, kernel=kernel)
            epoch_val_scores.append(val_score.item())
    scores[1].append(np.mean(epoch_val_scores))

kvadnets_lobe.eval()

# definition of the estimated transformation
def chi_kvadnet(X):
    """Evaluate the trained KVADNet lobe on positions ``X`` and return a numpy array.

    ``X`` is lifted with ``BickleyJet.to_3d``, cast to float32 and moved to
    ``device``; the network runs in eval mode without gradient tracking.
    """
    lifted = BickleyJet.to_3d(X)
    kvadnets_lobe.eval()
    inputs = torch.from_numpy(lifted.astype(np.float32)).to(device=device)
    with torch.no_grad():
        features = kvadnets_lobe(inputs)
    return features.cpu().numpy()
[31]:
# Training vs. validation score per epoch (logarithmic epoch axis).
plt.semilogx(scores[0], label='training')
plt.semilogx(scores[1], label='validation')
plt.xlabel('epoch')
plt.ylabel('score')
plt.legend();
../../_images/notebooks_examples_coherence-bickley-jet_52_0.png
[32]:
# KVAD estimate on top of the trained network features (chi_kvadnet), using the torch kernel.
# NOTE(review): this model is fitted on ds_test, whereas the KernelCCA, VAMP and
# KVAD estimators above are fitted on ds_train - confirm that this is intended.
kvadnet = KVAD(
    kernel=kernel, observable_transform=chi_kvadnet,
    dim=9
).fit_fetch(ds_test)

plot_singular_functions(kvadnet.transform)
../../_images/notebooks_examples_coherence-bickley-jet_53_0.png
[33]:
cluster_kvadnet = cluster(kvadnet.transform(ds_test.data))

def memberships_transform_kvadnet(x):
    """Project ``x`` with the KVADNet-based KVAD model and map it through the k-means model."""
    return cluster_kvadnet.transform(kvadnet.transform(x))

memberships = memberships_transform_kvadnet(ds_test.data)

# initial test positions, coloured by their cluster membership
f, ax = plt.subplots(1, 1)
ax.scatter(ds_test.data[:, 0], ds_test.data[:, 1], c=memberships)
ax.set_aspect('equal')
../../_images/notebooks_examples_coherence-bickley-jet_54_0.png
[34]:
# Score the KVADNet clustering and plot the particles whose assignment changed.
score_kvadnet, mismatches_kvadnet = coherence_score(memberships_transform_kvadnet)
print(f"Coherence score: {score_kvadnet:.2f}")
plot_mismatches(mismatches_kvadnet, memberships_transform_kvadnet)
Coherence score: 0.88
../../_images/notebooks_examples_coherence-bickley-jet_55_1.png

Comparison of scores

We obtain the following:

[40]:
def vamp2_score(transform):
    """Fit VAMP (dim=9) on ``ds_test`` with ``transform`` as observable; return ``score(2)``."""
    model = VAMP(observable_transform=transform, dim=9).fit_fetch(ds_test)
    return model.score(2)

# NOTE(review): this rebinds the name ``kvad_score`` that the notebook imports
# from deeptime.decomposition.deep and calls in the KVADNets training loop;
# re-running that loop after this cell requires re-running the import first.
def kvad_score(transform):
    """Fit KVAD (Gaussian kernel 0.5, dim=9) on ``ds_test`` and return its ``score``."""
    estimator = KVAD(GaussianKernel(0.5), observable_transform=transform, dim=9)
    return estimator.fit_fetch(ds_test).score
[47]:
import tabulate
def _kcca_real(x):
    # kernel CCA projection, real part only
    return kcca.transform(x).real

# (label, transformation to score, coherence score obtained above)
compared_methods = [
    ("KVADNets", kvadnet.transform, score_kvadnet),
    ("Kernel CCA", _kcca_real, score_kcca),
    ("VAMPNets", vamp_vampnet.transform, score_vampnet),
    ("VAMP", vamp.transform, score_vamp),
    ("KVAD", kvad.transform, score_kvad),
]
table_rows = [
    [label, vamp2_score(transform), kvad_score(transform), coherence]
    for label, transform, coherence in compared_methods
]
print(tabulate.tabulate(
    table_rows,
    headers=["Method", "VAMP-2 score", "KVAD score", "Coherence score"]
))
Method        VAMP-2 score    KVAD score    Coherence score
----------  --------------  ------------  -----------------
KVADNets           5.88749     0.0796252             0.8802
Kernel CCA         5.8305      0.0762447             0.8598
VAMPNets           7.07734     0.0749358             0.7622
VAMP               4.75647     0.0677209             0.738
KVAD               4.6162      0.0667439             0.7298