A deep dive on K-Means in which smart initialization and the full algorithm are implemented from scratch using PyTorch
import math, random, matplotlib.pyplot as plt, operator, torch
from functools import partial
from fastcore.all import *
from torch.distributions.multivariate_normal import MultivariateNormal
from torch import tensor
plt.style.use('dark_background')
torch.manual_seed(42)
torch.set_printoptions(precision=3, linewidth=140, sci_mode=False)
def plot_data(centroids:torch.Tensor, # Centroid coordinates
              data:torch.Tensor,      # Data coordinates
              n_samples:int,          # Number of samples per cluster
              ax:plt.Axes=None        # Matplotlib Axes object
             )-> None:
    '''Creates a visualization of centroids and data points for clustering problems'''
    if ax is None: _,ax = plt.subplots()
    for i, centroid in enumerate(centroids):
        # Each contiguous block of n_samples rows belongs to one cluster
        samples = data[i*n_samples:(i+1)*n_samples]
        ax.scatter(samples[:,0], samples[:,1], s=1)
        ax.plot(*centroid, markersize=10, marker="x", color='k', mew=5)
        ax.plot(*centroid, markersize=5, marker="x", color='m', mew=2)
We need to create a dataset. This data generation step follows what Jeremy Howard did in a notebook on mean shift clustering, a different clustering algorithm, from the fast.ai 2022 part 2 course.
Since the same dataset works here, I reused his code and removed some unneeded print statements. See the plot for what the data looks like.
n_clusters = 6
n_samples = 250
centroids = torch.rand(n_clusters, 2)*70-35 # Points between -35 and 35
def sample(m): return MultivariateNormal(m, torch.diag(tensor([5.,5.]))).sample((n_samples,))
data = torch.cat([sample(c) for c in centroids])
plot_data(centroids, data, n_samples)
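Since `torch.cat` concatenates the per-cluster samples in order, each contiguous block of `n_samples` rows in `data` was drawn around a single centroid, which is exactly the layout `plot_data` assumes when it slices. A quick sanity check (the shapes follow from the code above):
data.shape, centroids.shape  # (torch.Size([1500, 2]), torch.Size([6, 2])) -- 6 clusters * 250 samples
data[:n_samples].mean(axis=0)  # should land near centroids[0]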
K-means is a clustering algorithm. There are four main steps to the process:
+ Initialize centroids at smart starting positions
+ Calculate the distance between each data point and each centroid
+ Assign each data point to its closest centroid
+ Update each centroid by moving it to the mean of its assigned points
Once you have those steps, you can repeat the last 3 until your centroids no longer move.
In order to initialize our centroids we need to be able to calculate distances, so let's do that first.
Given a tensor of centroid coordinates and a tensor of data coordinates, we calculate the distance between a data point $x$ and a centroid $c$ as $\sqrt{(x_1 - c_1)^2 + (x_2 - c_2)^2}$.
Doing this for every pair at once via broadcasting gives us the euclidean distance between each data point and each centroid.
def calculate_distances(centroids:torch.Tensor, # Centroid coordinates
                        data:torch.Tensor # Data points you want to cluster
                       )-> torch.Tensor: # Tensor containing euclidean distance between each centroid and data point
    '''Calculate distance between centroids and each datapoint'''
    # Broadcast (N,1,2) - (1,K,2) -> (N,K,2) per-axis differences
    axis_distances = data.reshape(-1,1,2).sub(centroids.reshape(1,-1,2)).abs()
    # Square, sum over the coordinate axis, and take the root -> (N,K)
    euclid_distances = axis_distances.square().sum(axis=-1).sqrt()
    return euclid_distances
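To see the broadcasting at work, here is a quick sanity check with hypothetical toy values: two centroids and three data points should yield a 3x2 distance matrix, one row per data point and one column per centroid.
toy_centroids = tensor([[0., 0.], [3., 4.]])           # hypothetical example values
toy_data = tensor([[0., 0.], [3., 4.], [6., 8.]])
calculate_distances(toy_centroids, toy_data)  # rows = data points, cols = centroids: [[0, 5], [5, 0], [10, 5]]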
Where we initialize our centroids is really important. Without good initialization we are very likely to get stuck in a local optimum, especially with 6 centroids. One option is to run the algorithm many times and pick the best solution, but it's much better to start from good initializations in the first place.
We pick centroid locations in the following way:
+ Pick a random data point as the first centroid
+ For each remaining centroid, compute every data point's distance to its nearest already-chosen centroid, and pick the data point where that distance is largest
This ensures we get initializations that are far away from each other and spread out among the data, minimizing the risk of hitting local optima.
def initialize_centroids(data:torch.Tensor, # Data points you want to cluster
                         k:int # Number of centroids you want to initialize
                        )->torch.Tensor: # Returns starting centroid coordinates
    '''Initialize starting points for centroids as far from each other as possible.'''
    # Start from one randomly chosen data point
    pred_centroids = data[random.sample(range(0,len(data)),1)]
    for _ in range(k-1):
        # Add the data point that is farthest from its nearest existing centroid
        _centroid = data[calculate_distances(pred_centroids,data).min(axis=1).values.argmax()]
        pred_centroids = torch.stack([*pred_centroids,_centroid])
    return pred_centroids
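As a quick illustrative check (not part of the algorithm itself), we can reuse `calculate_distances` on the starting centroids themselves: if the initialization worked, even the smallest gap between any two of them should be large.
init = initialize_centroids(data, n_clusters)
dists = calculate_distances(init, init)  # pairwise distances between the chosen starting centroids
dists[dists > 0].min()  # smallest non-zero gap; a large value means well-spread starts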
Once we have centroids (or updated centroids), we need to assign a centroid to each data point. We do this by calculating the distance between each data point and each centroid, and assigning each data point to its closest centroid.
def assign_centroids(centroids:torch.Tensor, # Centroid coordinates
                     data:torch.Tensor # Data points you want to cluster
                    )->torch.Tensor: # Tensor containing new centroid assignments for each data point
    '''Based on distances update centroid assignments'''
    euclid_distances = calculate_distances(centroids,data)
    # Each data point gets the index of its nearest centroid
    assigned_cluster = euclid_distances.argmin(axis=1)
    return assigned_cluster
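Continuing with the toy tensors from the earlier sanity check, each row's argmin becomes the assignment: the first point is closest to centroid 0 and the other two to centroid 1.
assign_centroids(toy_centroids, toy_data)  # tensor([0, 1, 1])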
To update the centroid locations, we take the mean of all the data points assigned to each centroid and move the centroid to that point.
def update_centroids(centroid_assignments:torch.Tensor, # Centroid assignment for each data point
                     data:torch.Tensor # Data points you want to cluster
                    )->torch.Tensor: # Tensor containing updated centroid coordinates
    '''Update centroid locations'''
    n_centroids = len(centroid_assignments.unique())
    # New centroid location = mean of the points assigned to it
    pred_centroids = [data[centroid_assignments==i].mean(axis=0) for i in range(n_centroids)]
    return torch.stack(pred_centroids)
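One caveat worth noting: if a centroid ever ends up with no assigned points, `unique()` returns fewer cluster ids and the mean over an empty selection is NaN. That never happens on this dataset, but a more defensive variant could keep the old position for empty clusters. Here is a sketch; `update_centroids_safe` is a hypothetical helper, not part of the original notebook.
def update_centroids_safe(centroid_assignments:torch.Tensor, data:torch.Tensor, old_centroids:torch.Tensor)->torch.Tensor:
    '''Like update_centroids, but keeps a centroid in place when it has no assigned points.'''
    new_centroids = []
    for i in range(len(old_centroids)):
        members = data[centroid_assignments==i]
        # mean of an empty selection is NaN, so fall back to the previous position instead
        new_centroids.append(members.mean(axis=0) if len(members) else old_centroids[i])
    return torch.stack(new_centroids)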
Here we put it all together and train our K-means model. As you can see it fits this dataset very quickly (it's a simple dataset).
pred_centroids = initialize_centroids(data,n_clusters)
for epoch in range(3):
    plot_data(pred_centroids, data, n_samples)
    centroid_assignments = assign_centroids(pred_centroids,data)
    pred_centroids = update_centroids(centroid_assignments,data)
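Three epochs happen to be enough here, but the stopping rule described earlier is "centroids no longer move". A minimal sketch of that convergence-based loop, where the tolerance is an assumed value:
pred_centroids = initialize_centroids(data, n_clusters)
while True:
    centroid_assignments = assign_centroids(pred_centroids, data)
    new_centroids = update_centroids(centroid_assignments, data)
    # Stop once an update leaves every centroid (effectively) where it was
    if torch.allclose(new_centroids, pred_centroids, atol=1e-4): break  # atol is an assumed tolerance
    pred_centroids = new_centroids
plot_data(pred_centroids, data, n_samples)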