# Broadcasting in Python: K-means algorithm

written by [@marc_lelarge](https://twitter.com/marc_lelarge)

:::info
**Cost function:** for a partition $C_1, \dots, C_k$ of $n$ observations into $k$ clusters, the quality of the clustering is measured by the following cost function:
\begin{eqnarray*}
\sum_{j=1}^k \sum_{i\in C_j} \|x_i-\mu_j\|^2 \quad \text{where } \mu_j =\frac{1}{|C_j|}\sum_{i\in C_j} x_i.
\end{eqnarray*}
:::

![](https://i.imgur.com/sq6PEbl.png)

:::success
**k-means algorithm:**
- take $k$ random data samples as initial centers
- (a) when the centers are fixed, update the partition by assigning each data sample to its nearest center
- (b) when the partition is fixed, update the centers according to the formula: $\mu_j =\frac{1}{|C_j|}\sum_{i\in C_j} x_i$
- iterate steps (a) and (b)
:::

**Python code:** To focus on [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) in the partition step (a), we run the algorithm for a fixed number of iterations instead of checking a convergence criterion.

```python=
import numpy as np

class Kmeans:
    def __init__(self, k, num_iter=50):
        self.k = k
        self.num_iter = num_iter

    def partition_step(self):
        diff = self.data[:, None] - self.centroids[None]  # (n, k, d)
        distances = np.einsum('nkd,nkd->nk', diff, diff)  # (n, k)
        self.labels = np.argmin(distances, axis=1)  # (n,)

    def centers_step(self):
        self.centroids = np.stack([self.data[self.labels == i].mean(axis=0)
                                   for i in range(self.k)])  # (k, d)

    def fit(self, data):
        n, d = data.shape
        self.centroids = data[np.random.choice(n, self.k,
                                               replace=False)]  # (k, d)
        self.labels = np.empty(n)  # (n,)
        self.data = data
        for _ in range(self.num_iter):
            self.partition_step()
            self.centers_step()
        return self.centroids, self.labels
```

This code is 60x faster than a naive implementation with a for loop (see [kmeans_broadcast.ipynb](https://gist.github.com/mlelarge/8972b3cda4e1aad67fdecab877065adf)).

**Applications**

Below, we run the k-means algorithm on the colors of all the pixels of an image and then replace each original color (on the left) by the center of its cluster (code available in [this repo](https://github.com/mlelarge/agreg/blob/main/k-means.ipynb)):

![](https://i.imgur.com/270N5Ph.png)

###### tags: `public` `python` `machine learning`
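The cost function stated in the info box at the top of this note can be evaluated directly once the labels are known. The sketch below is a minimal illustration, assuming `labels[i]` holds the index $j$ of the cluster $C_j$ containing $x_i$; the helper names `kmeans_cost` and `cluster_means` are hypothetical and not part of the original code. Fancy indexing `centroids[labels]` builds an $(n, d)$ array whose row $i$ is $\mu_j$ for the cluster of $x_i$, so the double sum becomes a single sum of squared residuals.

```python
import numpy as np

def kmeans_cost(data, centroids, labels):
    """Cost sum_j sum_{i in C_j} ||x_i - mu_j||^2.

    data: (n, d), centroids: (k, d), labels: (n,) cluster index of each x_i.
    """
    # centroids[labels] has shape (n, d): row i is the center of x_i's cluster
    residuals = data - centroids[labels]
    return float(np.sum(residuals ** 2))

def cluster_means(data, labels, k):
    # mu_j = (1 / |C_j|) * sum_{i in C_j} x_i  (assumes no cluster is empty)
    return np.stack([data[labels == j].mean(axis=0) for j in range(k)])

data = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [10.0, 12.0]])
labels = np.array([0, 0, 1, 1])
mu = cluster_means(data, labels, k=2)
print(kmeans_cost(data, mu, labels))  # 4.0
```

Here $\mu_0 = (1, 0)$ and $\mu_1 = (10, 11)$, and each of the four points is at squared distance 1 from its center, hence a cost of 4.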
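To see what the broadcasting in `partition_step` replaces, the sketch below puts a naive double loop next to the broadcast version and checks that they produce the same labels. It is an illustration only, assuming random Gaussian data; `partition_naive` and `partition_broadcast` are hypothetical names. The key point is the shapes: `data[:, None]` is $(n, 1, d)$ and `centroids[None]` is $(1, k, d)$, so NumPy stretches both size-1 axes and the difference is $(n, k, d)$.

```python
import numpy as np

def partition_naive(data, centroids):
    # Explicit double loop: n * k Python-level iterations
    n, k = data.shape[0], centroids.shape[0]
    labels = np.empty(n, dtype=int)
    for i in range(n):
        best_j, best_dist = 0, np.inf
        for j in range(k):
            dist = np.sum((data[i] - centroids[j]) ** 2)
            if dist < best_dist:  # strict '<' keeps the first minimum, like argmin
                best_j, best_dist = j, dist
        labels[i] = best_j
    return labels

def partition_broadcast(data, centroids):
    # (n, 1, d) - (1, k, d) -> (n, k, d)
    diff = data[:, None] - centroids[None]
    distances = np.einsum('nkd,nkd->nk', diff, diff)  # squared norms, (n, k)
    return np.argmin(distances, axis=1)  # (n,)

rng = np.random.default_rng(0)
data = rng.normal(size=(500, 3))
centroids = data[rng.choice(500, size=4, replace=False)]
print((data[:, None] - centroids[None]).shape)  # (500, 4, 3)
same = np.array_equal(partition_naive(data, centroids),
                      partition_broadcast(data, centroids))
print(same)
```

The `einsum` call `'nkd,nkd->nk'` multiplies `diff` by itself elementwise and sums over the last axis, which is the squared Euclidean norm of each of the $n \times k$ difference vectors.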
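One cost of the `(n, k, d)` array `diff` in `partition_step` is memory: for a one-megapixel image with $k = 16$ and $d = 3$, it holds 48 million float64 values, about 384 MB. A common variant, swapped in here as an alternative rather than taken from the original code, expands $\|x - \mu\|^2 = \|x\|^2 - 2\langle x, \mu\rangle + \|\mu\|^2$ so that the largest intermediate is only $(n, k)$. The sketch below assumes float inputs; `partition_expanded` is a hypothetical name.

```python
import numpy as np

def partition_expanded(data, centroids):
    # ||x - mu||^2 = ||x||^2 - 2 <x, mu> + ||mu||^2, largest array is (n, k)
    x_sq = np.sum(data ** 2, axis=1)[:, None]     # (n, 1)
    mu_sq = np.sum(centroids ** 2, axis=1)[None]  # (1, k)
    cross = data @ centroids.T                    # (n, k) inner products
    distances = x_sq - 2 * cross + mu_sq          # (n, 1), (n, k), (1, k) -> (n, k)
    return np.argmin(distances, axis=1)

rng = np.random.default_rng(0)
data = rng.normal(size=(1000, 16))
centroids = rng.normal(size=(8, 16))
# Reference: the (n, k, d) broadcast used in partition_step
reference = np.argmin(((data[:, None] - centroids[None]) ** 2).sum(axis=2), axis=1)
labels = partition_expanded(data, centroids)
```

Two caveats: the $\|x\|^2$ term does not change the `argmin` and could be dropped, and the expansion can lose precision through cancellation when points and centers are far from the origin compared with their mutual distances. For RGB colors ($d = 3$) the memory saving is only about a factor of 3.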
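The centers step (b) in the `Kmeans` class still loops over the $k$ clusters, and for an empty cluster `self.data[self.labels == i].mean(axis=0)` returns NaNs (with a "Mean of empty slice" warning). The sketch below is one possible loop-free alternative using broadcasting to build an $(n, k)$ membership matrix; it is not the author's code, `centers_broadcast` is a hypothetical name, and keeping the previous center for an empty cluster is an assumed policy.

```python
import numpy as np

def centers_broadcast(data, labels, old_centroids):
    k = old_centroids.shape[0]
    # onehot[i, j] is True iff x_i is in C_j: (n, 1) == (1, k) -> (n, k)
    onehot = labels[:, None] == np.arange(k)[None]
    counts = onehot.sum(axis=0)                   # |C_j|, shape (k,)
    sums = onehot.T.astype(np.float64) @ data     # sum_{i in C_j} x_i, shape (k, d)
    safe = np.maximum(counts, 1)[:, None]         # avoid dividing by zero
    # Assumed policy: an empty cluster keeps its previous center
    return np.where(counts[:, None] > 0, sums / safe, old_centroids)

data = np.array([[0.0, 0.0], [2.0, 0.0], [10.0, 10.0], [10.0, 12.0]])
labels = np.array([0, 0, 1, 1])
old = np.array([[5.0, 5.0], [6.0, 6.0], [100.0, 100.0]])  # cluster 2 gets no points
new = centers_broadcast(data, labels, old)
print(new)  # rows: (1, 0), (10, 11), and the unchanged (100, 100)
```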
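The color-quantization application described under **Applications** reduces to a reshape: an $(h, w, 3)$ image becomes $n = hw$ observations in $\mathbb{R}^3$, k-means runs on those rows, and `centroids[labels]` maps every pixel to its cluster center. The sketch below is only an illustration under stated assumptions: it uses a synthetic two-color noisy image instead of a photo, and `kmeans` is a hypothetical compact helper that mirrors the `Kmeans` class (plus a guard that keeps the previous center of an empty cluster). It is not the code from the linked repo.

```python
import numpy as np

def kmeans(data, k, num_iter=20, seed=0):
    """Hypothetical helper: k random samples as initial centers, then (a) and (b)."""
    rng = np.random.default_rng(seed)
    centroids = data[rng.choice(len(data), size=k, replace=False)]
    for _ in range(num_iter):
        # (a) partition step: (n, 1, d) - (1, k, d) -> (n, k, d)
        diff = data[:, None] - centroids[None]
        labels = np.argmin(np.einsum('nkd,nkd->nk', diff, diff), axis=1)
        # (b) centers step; an empty cluster keeps its previous center
        centroids = np.stack([data[labels == j].mean(axis=0) if np.any(labels == j)
                              else centroids[j] for j in range(k)])
    return centroids, labels

# Synthetic stand-in for a photo: two flat colors plus Gaussian noise
rng = np.random.default_rng(1)
h, w = 32, 48
image = np.zeros((h, w, 3))
image[:, : w // 2] = [0.9, 0.1, 0.1]
image[:, w // 2 :] = [0.1, 0.2, 0.8]
image += rng.normal(scale=0.05, size=image.shape)

pixels = image.reshape(-1, 3)                   # (h * w, 3): one RGB row per pixel
centroids, labels = kmeans(pixels, k=2)
quantized = centroids[labels].reshape(h, w, 3)  # each pixel replaced by its center
```

For a real image loaded as `uint8` (for instance with Pillow, `np.asarray(Image.open(path).convert("RGB"))`, where `path` is a placeholder), convert to float before clustering: subtracting `uint8` arrays wraps around modulo 256, which would silently corrupt the distances in step (a).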