Clustering methods comparison

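This example applies deeptime's clustering methods (k-means with uniform initialization, k-means with k-means++ initialization, and regular space clustering) to the same example dataset for a direct comparison.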

Figure panels, one row per method (left: cluster centers, right: assignments): k-Means with uniform initialization; k-Means with k-means++ initialization; regular space clustering.
 8 import matplotlib.pyplot as plt  # matplotlib for plotting
 9 import numpy as np
10
11 from sklearn.mixture import GaussianMixture  # for example data generation
12
13 from deeptime.clustering import KMeans, RegularSpace
14
15 rnd = np.random.RandomState(seed=17)  # create a random state with fixed seed for reproducibility
16
17 n_components = 5
18 gmm = GaussianMixture(n_components=n_components, random_state=rnd, covariance_type='diag')  # create a GMM object
19
20 gmm.weights_ = np.array([10., 100., 10000., 10., 10.])
21 gmm.weights_ /= gmm.weights_.sum()  # weights need to form a probability distribution
22 gmm.means_ = rnd.uniform(low=-20., high=20., size=(n_components, 2))  # centers are random uniform
23 gmm.covariances_ = rnd.uniform(low=15., high=18., size=(n_components, 2))  # same for covariance matrices
24
25 samples, labels = gmm.sample(50000)  # generate data
26
27 estimators = [
28     ('k-Means with uniform initialization', KMeans(
29         n_clusters=100,  # place 100 cluster centers
30         init_strategy='uniform',  # uniform initialization strategy
31         fixed_seed=13,
32         n_jobs=8)
33      ),
34     ('k-Means with k-means++ initialization', KMeans(
35         n_clusters=100,  # place 100 cluster centers
36         init_strategy='kmeans++',  # uniform initialization strategy
37         fixed_seed=13,
38         n_jobs=8)
39      ),
40     ('Regular space clustering', RegularSpace(
41         dmin=3,  # minimum distance between cluster centers
42         max_centers=300,  # maximum number of cluster centers
43         n_jobs=8)
44      )
45 ]
46
47 f, axes = plt.subplots(3, 2, figsize=(15, 15))
48
49 for i, (label, estimator) in enumerate(estimators):
50     clustering = estimator.fit(samples).fetch_model()
51     ax1 = axes[i][0]
52     ax2 = axes[i][1]
53
54     ax1.hexbin(*samples.T, bins='log')
55     ax1.scatter(*clustering.cluster_centers.T, marker='o', c='m')
56     ax1.axis('off')
57     ax1.set_title(label + ': cluster centers')
58
59     ax2.scatter(*samples[::10].T, c=clustering.transform(samples)[::10])
60     ax2.axis('off')
61     ax2.set_title(label + ': assignments')

Total running time of the script: (0 minutes 8.261 seconds)

Estimated memory usage: 13 MB