Extracting Dominant Colours from an Image using K-means Clustering from Scratch
Extract the dominant colours from any image of your choice in less than 5 minutes from scratch!
You’ve probably heard the phrase “a picture is worth a thousand words.” In our digitally advanced age, this is truer than ever: a great deal of information can be extracted from an image. High-level computer vision systems allow self-driving cars to tell whether an object up ahead is a pedestrian crossing or a static road hazard, and Instagram filters can detect faces and respond interactively. These advances build on many of the fundamental approaches of machine learning.
For more on Self-Driving Cars: A Beginner’s Guide to Reinforcement Learning and its Basic Implementation from Scratch
Machine learning is the process by which machines learn from data in order to understand it and answer questions about it. In image processing, one application of machine learning is to treat an image digitally: its pixels and their colours are represented as numbers, and those numbers become the data.
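As a concrete illustration, assuming NumPy and a tiny hand-made 2 × 3 image in place of a real photo, an RGB image is simply a height × width × 3 array of numbers. Reshaping it to one row per pixel turns it into the kind of dataset a clustering algorithm works on.

```python
import numpy as np

# A tiny hand-made 2 x 3 RGB image: height x width x channels, values 0-255
image = np.array(
    [
        [[255, 0, 0], [255, 0, 0], [0, 0, 255]],
        [[0, 255, 0], [255, 0, 0], [0, 0, 255]],
    ],
    dtype=np.uint8,
)

height, width, channels = image.shape
print(image.shape)  # (2, 3, 3)

# Flatten to one row per pixel: each row is an (R, G, B) feature vector
pixels = image.reshape(height * width, channels)
print(pixels.shape)  # (6, 3)

# Count the distinct colours (rows) present in the image
distinct = np.unique(pixels, axis=0)
print(len(distinct))  # 3
```

A real photograph typically contains thousands of distinct colours. Clustering is one way to summarise them with just a few representative ones.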
For more on Machine Learning: A Beginner’s Guide for Getting Started with Machine Learning
Approaches that don’t make predictions or assume a correct set of outputs, but instead uncover insights from a given dataset, are referred to as unsupervised. One such technique for image processing and information extraction is K-means clustering, a learning approach that aims to partition n data points into k groups, assigning each point to the group whose mean (centroid) is nearest.
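To make “partition into k groups” precise: K-means looks for the grouping that minimises the within-cluster sum of squared distances, that is, the squared distance from each point to the mean of its own group, added up over all points. The sketch below computes that quantity for a given assignment. It uses a hypothetical helper, `within_cluster_sse`, and four hand-picked 2-D points. The sensible grouping scores far lower than a mixed-up one.

```python
import numpy as np


def within_cluster_sse(X, labels, k):
    """Sum of squared distances from each point to the mean of its cluster.

    Hypothetical helper for illustration: K-means searches for the labels
    that make this value as small as possible.
    """
    total = 0.0
    for j in range(k):
        members = X[labels == j]
        if len(members) == 0:
            continue  # an empty cluster contributes nothing
        centroid = members.mean(axis=0)
        total += np.sum((members - centroid) ** 2)
    return total


# Two obvious groups of 2-D points
X = np.array([[0.0, 0.0], [0.0, 1.0], [10.0, 10.0], [10.0, 11.0]])

good = np.array([0, 0, 1, 1])  # each group in its own cluster
bad = np.array([0, 1, 0, 1])   # the groups mixed across clusters

print(within_cluster_sse(X, good, k=2))  # 1.0
print(within_cluster_sse(X, bad, k=2))   # 200.0
```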
For a conceptual overview of K-means clustering, refer to: Everything you need to know about K-Means Clustering
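In brief, standard K-means (Lloyd's algorithm) starts from k initial centroids $\mu_1, \dots, \mu_k$. It then alternates an assignment step and an update step until the centroids stop changing. The class in the walkthrough below follows this same loop and initialises the centroids with k randomly chosen data points:

$$
S_j \leftarrow \{\, x_i : j = \arg\min_{l} \lVert x_i - \mu_l \rVert \,\} \qquad \text{(assignment)}
$$

$$
\mu_j \leftarrow \frac{1}{|S_j|} \sum_{x_i \in S_j} x_i \qquad \text{(update)}
$$

Neither step can increase the within-cluster sum of squared distances $\sum_{j=1}^{k} \sum_{x_i \in S_j} \lVert x_i - \mu_j \rVert^2$. This is why the loop terminates, although it may settle in a local rather than a global minimum.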
We shall now begin with a code walkthrough of the K-means clustering algorithm implemented from scratch:
Fret not! I promise you that it’s going to turn out as fascinating as it sounds!
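Everything in the class below rests on one measure: the Euclidean distance between two colour vectors in RGB space. Here is a small worked example with arbitrarily chosen colours. Pure red (255, 0, 0) and orange (255, 165, 0) differ only in the green channel, so their distance is exactly 165. Red and pure blue are 255·√2 ≈ 360.6 apart.

The last lines show a pitfall worth knowing. Pixel values loaded from image files are usually `uint8`, and subtracting `uint8` arrays wraps around modulo 256 instead of going negative. Converting to float before computing distances is therefore the safer choice.

```python
import numpy as np


def euclidean_distance(x1, x2):
    # straight-line distance between two feature vectors
    return np.sqrt(np.sum((x1 - x2) ** 2))


red = np.array([255, 0, 0], dtype=float)
orange = np.array([255, 165, 0], dtype=float)
blue = np.array([0, 0, 255], dtype=float)

print(euclidean_distance(red, orange))  # 165.0
print(euclidean_distance(red, blue))    # about 360.62 (= 255 * sqrt(2))

# Pitfall: image pixels are usually uint8, where subtraction wraps around
a = np.array([10], dtype=np.uint8)
b = np.array([20], dtype=np.uint8)
print(a - b)                              # [246], not [-10]
print(a.astype(float) - b.astype(float))  # [-10.]
```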
```python
import numpy as np
import matplotlib.pyplot as plt

# fix the seed so the random initial centroids are reproducible
np.random.seed(42)


def euclidean_distance(x1, x2):
    return np.sqrt(np.sum((x1 - x2) ** 2))


class KMeans():
    def __init__(self, K=5, max_iters=100, plot_steps=False):
        self.K = K
        self.max_iters = max_iters
        self.plot_steps = plot_steps

        # list of sample indices for each cluster
        self.clusters = [[] for _ in range(self.K)]
        # the centers (mean feature vector) for each cluster
        self.centroids = []

    def predict(self, X):
        # work in floating point: image pixels are usually uint8, and
        # subtracting uint8 values wraps around instead of going negative
        self.X = np.asarray(X, dtype=float)
        self.n_samples, self.n_features = self.X.shape

        # initialize the centroids with K distinct random samples
        random_sample_idxs = np.random.choice(self.n_samples, self.K, replace=False)
        self.centroids = [self.X[idx] for idx in random_sample_idxs]

        # Optimize clusters
        for _ in range(self.max_iters):
            # Assign samples to closest centroids (create clusters)
            self.clusters = self._create_clusters(self.centroids)

            if self.plot_steps:
                self.plot()

            # Calculate new centroids from the clusters
            centroids_old = self.centroids
            self.centroids = self._get_centroids(self.clusters)

            # stop once the centroids no longer move
            if self._is_converged(centroids_old, self.centroids):
                break

            if self.plot_steps:
                self.plot()

        # Classify samples as the index of their clusters
        return self._get_cluster_labels(self.clusters)

    def _get_cluster_labels(self, clusters):
        # each sample will get the label of the cluster it was assigned to
        labels = np.empty(self.n_samples)
        for cluster_idx, cluster in enumerate(clusters):
            for sample_index in cluster:
                labels[sample_index] = cluster_idx
        return labels

    def _create_clusters(self, centroids):
        # Assign the samples to the closest centroids to create clusters
        clusters = [[] for _ in range(self.K)]
        for idx, sample in enumerate(self.X):
            centroid_idx = self._closest_centroid(sample, centroids)
            clusters[centroid_idx].append(idx)
        return clusters

    def _closest_centroid(self, sample, centroids):
        # distance of the current sample to each centroid
        distances = [euclidean_distance(sample, point) for point in centroids]
        closest_index = np.argmin(distances)
        return closest_index

    def _get_centroids(self, clusters):
        # assign the mean value of each cluster to its centroid
        centroids = np.zeros((self.K, self.n_features))
        for cluster_idx, cluster in enumerate(clusters):
            if cluster:
                centroids[cluster_idx] = np.mean(self.X[cluster], axis=0)
            else:
                # empty cluster: keep its previous centroid instead of a NaN mean
                centroids[cluster_idx] = self.centroids[cluster_idx]
        return centroids

    def _is_converged(self, centroids_old, centroids):
        # distances between each old and new centroid, for all centroids
        distances = [euclidean_distance(centroids_old[i], centroids[i]) for i in range(self.K)]
        return sum(distances) == 0

    def plot(self):
        # scatter each cluster and mark the centroids with an "x" (intended for 2-D data)
        fig, ax = plt.subplots(figsize=(12, 8))
        for index in self.clusters:
            point = self.X[index].T
            ax.scatter(*point)
        for point in self.centroids:
            ax.scatter(*point, marker="x", color="black", linewidth=2)
        plt.show()

    def cent(self):
        # final centroids: for an image, one mean RGB colour per cluster
        return self.centroids
```

Extracting Dominant Colours in an Image

```python
import cv2
from skimage import io
# from google.colab.patches import cv2_imshow  # Colab-only helper; not used below

url = "https://www.teahub.io/photos/full/35-355143_windows-10-wallpaper-umbrella.jpg"
img = io.imread(url)
img.shape
# Out: (1080, 1920, 3)

img_init = img.copy()  # keep an unflattened copy for display
plt.figure(figsize=(6, 6))
plt.imshow(img_init)
# Out: <matplotlib.image.AxesImage at 0x7f49e6a7d6a0>

# flatten the H x W x 3 image into an (H*W) x 3 array of RGB pixels
img = img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
k = KMeans(K=5)  # for the 5 most dominant colours
y_pred = k.predict(img)
k.cent()
# Out:
# array([[ 53.16708662,  93.69632655, 175.47967713],
#        [133.07051658, 195.54817432,  47.66459003],
#        [237.62760178,  76.63096981,  21.42656026],
#        [248.665852  ,  31.23121874, 121.13346739],
#        [206.22142881, 229.36967717, 152.23724866]])

y_pred
# Out: array([2., 2., 2., ..., 0., 0., 0.])

# one histogram bin per cluster label: bin edges 0, 1, ..., K
label_indx = np.arange(0, k.K + 1)
label_indx
# Out: array([0, 1, 2, 3, 4, 5])

np.histogram(y_pred, bins=label_indx)
# Out: (array([565545, 172073, 559377, 593291, 183314]), array([0, 1, 2, 3, 4, 5]))

# turn the pixel count of each cluster into its fraction of the image
(hist, _) = np.histogram(y_pred, bins=label_indx)
hist = hist.astype("float")
hist /= hist.sum()
hist
# Out: array([0.27273582, 0.08298274, 0.26976128, 0.28611642, 0.08840374])

# draw a 50 x 300 bar in which each colour's width is proportional to its share
hist_bar = np.zeros((50, 300, 3), dtype="uint8")
startX = 0
for (percent, color) in zip(hist, k.cent()):
    endX = startX + (percent * 300)  # scale the fraction to the 300-px bar width
    cv2.rectangle(hist_bar, (int(startX), 0), (int(endX), 50),
                  color.astype("uint8").tolist(), -1)  # thickness -1 fills the box
    startX = endX

# show the original image next to its dominant-colour bar
plt.figure(figsize=(15, 15))
plt.subplot(121)
plt.imshow(img_init)
plt.subplot(122)
plt.imshow(hist_bar)
plt.show()
```
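The class above computes distances one pixel at a time in a Python loop. That is slow for this image's 2,073,600 pixels (1080 × 1920). The sketch below assumes only NumPy and writes the same assign-then-update loop with array operations, so each step handles all pixels at once. `assign_labels` and `kmeans_vectorized` are hypothetical names, not part of the original code, and the synthetic “pixels” stand in for a real image. The final lines compute each cluster's share of pixels with `np.bincount`. That gives the same fractions as the normalised `np.histogram` step above.

```python
import numpy as np


def assign_labels(X, centroids):
    """Index of the nearest centroid for every row of X (hypothetical helper)."""
    # ||x - c||^2 = ||x||^2 - 2 x.c + ||c||^2, computed for all pairs at once;
    # this needs an (n, K) array instead of an (n, K, features) one
    d2 = (
        (X ** 2).sum(axis=1)[:, None]
        - 2 * X @ centroids.T
        + (centroids ** 2).sum(axis=1)[None, :]
    )
    return d2.argmin(axis=1)


def kmeans_vectorized(X, K=5, max_iters=100, seed=42):
    """Hypothetical vectorized K-means; returns (centroids, labels)."""
    rng = np.random.default_rng(seed)
    X = np.asarray(X, dtype=float)  # float avoids uint8 wrap-around
    # start from K distinct random samples, as the class above does
    centroids = X[rng.choice(len(X), size=K, replace=False)]

    for _ in range(max_iters):
        labels = assign_labels(X, centroids)
        new_centroids = centroids.copy()  # empty clusters keep their old centroid
        for j in range(K):
            members = X[labels == j]
            if len(members) > 0:
                new_centroids[j] = members.mean(axis=0)
        if np.allclose(new_centroids, centroids):
            break
        centroids = new_centroids

    # final labels consistent with the returned centroids
    return centroids, assign_labels(X, centroids)


# Toy usage on synthetic "pixels": 300 reddish and 100 bluish RGB values
rng = np.random.default_rng(0)
reds = rng.normal([230, 30, 30], 5, size=(300, 3))
blues = rng.normal([30, 30, 230], 5, size=(100, 3))
pixels = np.clip(np.vstack([reds, blues]), 0, 255)

centroids, labels = kmeans_vectorized(pixels, K=2)
# fraction of pixels per cluster, equivalent to the normalised np.histogram
shares = np.bincount(labels, minlength=2) / len(labels)
print(np.round(centroids), shares)
```

The pairwise-distance expansion is a design choice for memory. On the full image, an (n, K) array of 2,073,600 × 5 floats is far smaller than the (n, K, 3) array that direct broadcasting of `X[:, None, :] - centroids` would create.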
I hope you enjoyed this article and made the most of it! Stay tuned for my upcoming blogs, and make sure to clap and follow if you found my content helpful or informative!