Introduction
Clustering is a fundamental technique in unsupervised machine learning, where the goal is to group similar data points together based on their features. K-means is one of the most widely used clustering algorithms because it is simple and efficient. In this post, we will build K-means from scratch, step by step, to understand how it works.
What is K-means?
K-means is an iterative algorithm that partitions data into K clusters by assigning each data point to the cluster whose centroid is nearest. Each cluster is represented by the mean (centroid) of the data points assigned to it. The "K" in K-means is the number of clusters we want to identify in the data, and we choose it before running the algorithm.
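One way to make "similarity to a centroid" concrete is the within-cluster sum of squared distances, the quantity K-means tries to reduce. The snippet below is a minimal sketch on a toy dataset; the function name `within_cluster_ss` and the sample arrays are assumptions made for illustration, not part of the implementation later in this post.

```python
import numpy as np

def within_cluster_ss(data, centroids, labels):
    """Sum of squared distances from each point to its assigned centroid."""
    # centroids[labels] picks, for every point, the centroid it belongs to
    diffs = data - centroids[labels]
    return float(np.sum(diffs ** 2))

# Toy example (hypothetical values): two tight pairs of points
points = np.array([[0.0, 0.0], [0.0, 2.0], [10.0, 0.0], [10.0, 2.0]])
centroids = np.array([[0.0, 1.0], [10.0, 1.0]])
labels = np.array([0, 0, 1, 1])

wcss = within_cluster_ss(points, centroids, labels)
print(wcss)  # 4.0 (each point is exactly 1 unit from its centroid)
```

A lower value means points sit closer to their centroids. Moving the centroids away from the cluster means increases it.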
Algorithm Steps
a. Initialization:
Randomly select K data points as the initial centroids.
b. Assignment:
Assign each data point to its nearest centroid to form K clusters.
c. Update Centroids:
Calculate the mean of each cluster's data points to find the new centroids.
d. Convergence:
Check whether the centroids have changed. If they have, repeat the assignment and update steps. If not, the algorithm has converged, and we can stop.
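To see the steps above in action before writing the full function, here is a small sketch of a single pass on four hand-picked points. The data and the starting centroids are made up for illustration (in the real algorithm the starting centroids are chosen at random).

```python
import numpy as np

# Four toy points forming two obvious groups (hypothetical values)
points = np.array([[1.0, 1.0], [1.0, 2.0], [8.0, 8.0], [9.0, 8.0]])
# Pretend these two data points were picked as the initial centroids
centroids = np.array([[1.0, 1.0], [9.0, 8.0]])

# Assign each point to its nearest centroid
distances = np.linalg.norm(points[:, np.newaxis] - centroids, axis=2)
labels = np.argmin(distances, axis=1)

# Move each centroid to the mean of its assigned points
new_centroids = np.array([points[labels == i].mean(axis=0) for i in range(2)])

print(labels.tolist())         # [0, 0, 1, 1]
print(new_centroids.tolist())  # [[1.0, 1.5], [8.5, 8.0]]

# The centroids moved, so the algorithm would run another pass
print(np.allclose(centroids, new_centroids))  # False
```

Running the same two steps again with `new_centroids` leaves the labels and the means unchanged, which is the convergence condition.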
Implementing K-means from Scratch in Python
Let's now implement the K-means algorithm from scratch in Python. We'll use the NumPy library for numerical computations and Matplotlib to plot the results.
Step 1: Import the necessary libraries
import numpy as np
import matplotlib.pyplot as plt
Step 2: Define the K-means function
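def kmeans(data, k, max_iterations=100):
    # Step 2a: Initialize centroids by picking k distinct data points at random
    centroids = data[np.random.choice(data.shape[0], k, replace=False)]
    for _ in range(max_iterations):
        # Step 2b: Assign each data point to its nearest centroid.
        # Broadcasting gives an (n, k) matrix of point-to-centroid distances.
        distances = np.linalg.norm(data[:, np.newaxis] - centroids, axis=2)
        labels = np.argmin(distances, axis=1)
        # Step 2c: Update each centroid to the mean of its assigned points.
        # A cluster with no points keeps its previous centroid (avoids NaN).
        new_centroids = np.array([
            data[labels == i].mean(axis=0) if np.any(labels == i) else centroids[i]
            for i in range(k)
        ])
        # Step 2d: Stop once the centroids no longer move
        if np.allclose(centroids, new_centroids):
            break
        centroids = new_centroids
    return centroids, labels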
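The assignment step packs a lot into one expression, so it may help to look at the array shapes it produces. The sketch below uses small made-up arrays (5 points, 3 centroids, 2 features) purely to show how broadcasting builds the distance matrix.

```python
import numpy as np

# Hypothetical toy inputs: 5 points and 3 centroids in 2 dimensions
data = np.arange(10.0).reshape(5, 2)  # rows: [0,1], [2,3], [4,5], [6,7], [8,9]
centroids = np.array([[1.0, 1.0], [4.0, 5.0], [9.0, 9.0]])

# (5, 1, 2) minus (3, 2) broadcasts to (5, 3, 2): every point against every centroid
diffs = data[:, np.newaxis] - centroids

# Taking the norm over the feature axis leaves one distance per (point, centroid) pair
distances = np.linalg.norm(diffs, axis=2)

# The column index of the smallest distance in each row is that point's cluster
labels = np.argmin(distances, axis=1)

print(diffs.shape, distances.shape, labels.shape)  # (5, 3, 2) (5, 3) (5,)
print(labels.tolist())  # [0, 0, 1, 1, 2]
```

The intermediate array holds n × k × d values, so for very large datasets this approach can use a lot of memory.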
Step 3: Generate sample data and run K-means
# Generate random data
np.random.seed(42)
data = np.random.randn(100, 2)
# Set the number of clusters (K)
k = 3
# Run K-means on the data
centroids, labels = kmeans(data, k)
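Note that the sample data above is drawn from a single Gaussian, so it has no natural groups and K-means will simply split the cloud into three regions. If you would like data with visible clusters, one option is to sample points around a few chosen centers. The sketch below is one way to do that; the names `centers`, `blobs`, and `true_labels`, as well as the center positions and spread, are assumptions chosen for illustration.

```python
import numpy as np

rng = np.random.default_rng(42)

# Three hypothetical cluster centers, well separated relative to the spread
centers = np.array([[0.0, 0.0], [5.0, 5.0], [0.0, 5.0]])

# Sample 50 points around each center with a standard deviation of 0.5
blobs = np.vstack([c + rng.normal(scale=0.5, size=(50, 2)) for c in centers])

# Remember which center each point came from, for comparison later
true_labels = np.repeat(np.arange(3), 50)

print(blobs.shape)        # (150, 2)
print(true_labels.shape)  # (150,)
```

Passing `blobs` to `kmeans` in place of `data` should produce clusters that are easier to interpret in the plot. The cluster numbers K-means returns are arbitrary, so they will not necessarily match `true_labels` index for index.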
Step 4: Visualize the results
# Scatter plot the data points with different colors for each cluster
plt.scatter(data[:, 0], data[:, 1], c=labels)
# Plot the centroids as 'X'
plt.scatter(centroids[:, 0], centroids[:, 1], marker='X', s=100, c='red')
plt.title('K-means Clustering')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()
Conclusion
In this post, we covered the basics of the K-means clustering algorithm and implemented it from scratch in Python. Walking through the iterative steps of initialization, assignment, centroid updates, and convergence shows how the algorithm works. K-means is a versatile tool used in domains like customer segmentation, image compression, and anomaly detection. As you continue your journey in machine learning, a strong grasp of these fundamentals will serve as a solid foundation for tackling more complex clustering challenges. Happy clustering!