# Broadcasting in Python: K-means algorithm
written by [@marc_lelarge](https://twitter.com/marc_lelarge)
:::info
**Cost function:** for a partition $C_1, \dots, C_k$ of $n$ observations into $k$ clusters, the quality of the clustering is measured by the following cost function:
\begin{eqnarray*}
\sum_{j=1}^k \sum_{i\in C_j} \|x_i-\mu_j\|^2, \quad \text{where } \mu_j =\frac{1}{|C_j|}\sum_{i\in C_j} x_i.
\end{eqnarray*}
:::
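
As a minimal sketch, the cost function above can be evaluated directly with NumPy. The function name `kmeans_cost` and the toy data are our own illustrative choices, and `labels[i] = j` is assumed to encode $i \in C_j$.

```python
import numpy as np

def kmeans_cost(data, labels, k):
    """Sum of squared distances of each point to the mean of its cluster."""
    cost = 0.0
    for j in range(k):
        cluster = data[labels == j]       # points of C_j, shape (|C_j|, d)
        if len(cluster) == 0:             # an empty cluster contributes nothing
            continue
        mu_j = cluster.mean(axis=0)       # center of C_j, shape (d,)
        cost += np.sum((cluster - mu_j) ** 2)
    return cost

# Toy data (hypothetical): two tight pairs of points in the plane
data = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
good = np.array([0, 0, 1, 1])   # groups nearby points together
bad = np.array([0, 1, 0, 1])    # mixes the two pairs

print(kmeans_cost(data, good, 2))  # 4.0
print(kmeans_cost(data, bad, 2))   # 100.0
```

The partition that groups nearby points has a much lower cost, which is what k-means tries to achieve.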

:::success
**k-means algorithm:**
- take $k$ random centers
- (a) when centers are fixed, update the partition by allocating each data sample to its nearest center
- (b) when the partition is fixed, update the centers according to the formula: $\mu_j =\frac{1}{|C_j|}\sum_{i\in C_j} x_i$
- iterate steps (a) and (b)
:::
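
To make steps (a) and (b) concrete, here is a small sketch on four points of the real line. The data and the initial centers are hand-picked for illustration (not random, as in the algorithm above) so that the run is reproducible. The first iteration gives the labels `[0, 1, 1, 1]`; from the second iteration on, the partition is `[0, 0, 1, 1]` with centers `0.5` and `10.5`, and it no longer changes.

```python
import numpy as np

x = np.array([0.0, 1.0, 10.0, 11.0])   # n = 4 observations in dimension 1
centers = np.array([0.0, 1.0])         # k = 2 hand-picked initial centers (assumption)

for _ in range(3):
    # (a) centers fixed: assign each point to its nearest center
    labels = np.argmin(np.abs(x[:, None] - centers[None]), axis=1)
    # (b) partition fixed: move each center to the mean of its points
    centers = np.array([x[labels == j].mean() for j in range(2)])
    print(labels, centers)
```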
**Python code:**
To focus on [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) in the partition step (a), we run the algorithm for a fixed number of iterations instead of checking a convergence criterion.
```python=
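import numpy as np

class Kmeans:
    def __init__(self, k, num_iter=50):
        self.k = k
        self.num_iter = num_iter

    def partition_step(self):
        # (a) broadcast (n, 1, d) against (1, k, d)
        diff = self.data[:, None] - self.centroids[None]  # (n, k, d)
        # squared Euclidean distances: sum over the d axis
        distances = np.einsum('nkd,nkd->nk', diff, diff)  # (n, k)
        self.labels = np.argmin(distances, axis=1)        # (n,)

    def centers_step(self):
        # (b) mean of each cluster; an empty cluster keeps its previous
        # center (otherwise the mean of an empty slice would be NaN)
        self.centroids = np.stack([
            self.data[self.labels == i].mean(axis=0)
            if np.any(self.labels == i) else self.centroids[i]
            for i in range(self.k)])                      # (k, d)

    def fit(self, data):
        n, d = data.shape
        # initial centers: k distinct data points chosen at random
        self.centroids = data[np.random.choice(n, self.k, replace=False)]  # (k, d)
        self.labels = np.empty(n, dtype=int)              # (n,)
        self.data = data
        for _ in range(self.num_iter):
            self.partition_step()
            self.centers_step()
        return self.centroids, self.labels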
```
This code is 60x faster than a naive implementation with a for loop; see [kmeans_broadcast.ipynb](https://gist.github.com/mlelarge/8972b3cda4e1aad67fdecab877065adf).
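
For comparison, here is a sketch of what a loop-based partition step could look like next to the broadcast version. This is our own illustrative baseline, not necessarily the one used in the notebook; the function names and the random data are assumptions. The data takes integer values so that all distances are computed exactly and both versions return identical labels. The two functions can then be timed with `timeit`; the actual speedup depends on the machine and on the sizes `n`, `k`, `d`.

```python
import numpy as np

def partition_loop(data, centroids):
    """Naive partition step: explicit Python loops over points and centers."""
    n, k = data.shape[0], centroids.shape[0]
    labels = np.empty(n, dtype=int)
    for i in range(n):
        best, best_dist = 0, np.inf
        for j in range(k):
            dist = np.sum((data[i] - centroids[j]) ** 2)
            if dist < best_dist:          # strict <: first minimum wins, like argmin
                best, best_dist = j, dist
        labels[i] = best
    return labels

def partition_broadcast(data, centroids):
    """Same step with broadcasting: (n, 1, d) - (1, k, d) -> (n, k, d)."""
    diff = data[:, None] - centroids[None]
    distances = np.einsum('nkd,nkd->nk', diff, diff)   # (n, k)
    return np.argmin(distances, axis=1)                # (n,)

rng = np.random.default_rng(0)
data = rng.integers(0, 256, size=(500, 3)).astype(float)  # hypothetical data
centroids = data[:4]                                       # 4 arbitrary centers

same = np.array_equal(partition_loop(data, centroids),
                      partition_broadcast(data, centroids))
print(same)  # True
```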
**Applications**
Below, we run the k-means algorithm on the colors of all the pixels of an image, then replace the original color of each pixel (on the left) by the center of its cluster (code available in [this repo](https://github.com/mlelarge/agreg/blob/main/k-means.ipynb)).
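
A minimal sketch of this color quantization, assuming the image is an array of shape `(h, w, 3)`. The function `quantize_colors` is a hypothetical helper written for this note, not the code from the repo, and a random synthetic image stands in for a real picture to keep the example self-contained.

```python
import numpy as np

def quantize_colors(image, k, num_iter=20, seed=0):
    """Replace every pixel color by the center of its k-means cluster."""
    h, w, c = image.shape
    pixels = image.reshape(-1, c).astype(float)                   # (h*w, c)
    rng = np.random.default_rng(seed)
    centers = pixels[rng.choice(len(pixels), k, replace=False)]   # (k, c)
    for _ in range(num_iter):
        # partition step with broadcasting
        diff = pixels[:, None] - centers[None]                    # (h*w, k, c)
        labels = np.argmin(np.einsum('nkd,nkd->nk', diff, diff), axis=1)
        # centers step; an empty cluster keeps its previous center
        centers = np.stack([
            pixels[labels == j].mean(axis=0) if np.any(labels == j) else centers[j]
            for j in range(k)])
    return centers[labels].reshape(h, w, c)

rng = np.random.default_rng(1)
image = rng.integers(0, 256, size=(32, 32, 3))   # synthetic stand-in for a picture
quantized = quantize_colors(image, k=8)
n_colors = len(np.unique(quantized.reshape(-1, 3), axis=0))
print(quantized.shape, n_colors)
```

The output image has the same shape as the input but contains at most `k` distinct colors.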

###### tags: `public` `python` `machine learning`