Extracting Dominant Colours from an Image using K-means Clustering from Scratch

Tanvi Penumudy
Published in Analytics Vidhya
4 min read · Jan 16, 2021


Extract the dominant colours from any image of your choice in less than 5 minutes from scratch!

You’ve probably heard the phrase “a picture is worth a thousand words.” In our digitally advanced age, this is more accurate than ever: a lot of information can be extracted from an image. High-level computer vision systems allow self-driving cars to recognize whether an object up ahead is a pedestrian crossing or a static road hazard, and Instagram filters detect faces and respond interactively. These advances are built on the fundamental approaches of machine learning.

For more on Self-Driving Cars: A Beginner’s Guide to Reinforcement Learning and its Basic Implementation from Scratch

Machine learning is the process by which machines learn from data in order to understand it and answer questions about it. In image processing, this means treating an image digitally as data: every pixel is represented by numbers that encode its colour.
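To make "pixels as data" concrete, here is a small sketch using a hypothetical 2×3 image (not from the article). A colour image is an array of shape (height, width, 3); reshaping it to one row per pixel turns every pixel into a 3-D point (R, G, B), which is the form the clustering code later in this article works on.

```python
import numpy as np

# A tiny hypothetical 2x3 RGB image: height 2, width 3, 3 colour channels (0-255)
tiny_img = np.array([
    [[255, 0, 0], [0, 255, 0], [0, 0, 255]],
    [[255, 255, 255], [0, 0, 0], [128, 128, 128]],
], dtype=np.uint8)

print(tiny_img.shape)              # (2, 3, 3)

# Flatten to one row per pixel: each row is an (R, G, B) data point
pixels = tiny_img.reshape(-1, tiny_img.shape[2])
print(pixels.shape)                # (6, 3)
print(pixels[1].tolist())          # [0, 255, 0]  (the green pixel)
```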

For more on Machine Learning: A Beginner’s Guide for Getting Started with Machine Learning

Approaches that do not make predictions or assume a known set of correct outputs, but instead uncover insights from a given dataset, are called unsupervised. One such technique for image processing and information extraction is K-means clustering, which partitions n data points into k groups (clusters) by assigning each point to the cluster whose mean, or centroid, is nearest.
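For reference, the standard K-means objective (a textbook formulation, not stated in the original article) chooses clusters $C_1,\dots,C_k$ that minimise the within-cluster sum of squared distances, with each centroid $\mu_j$ being the mean of its cluster:

```latex
\min_{C_1,\dots,C_k} \; \sum_{j=1}^{k} \sum_{x_i \in C_j} \lVert x_i - \mu_j \rVert_2^2,
\qquad
\mu_j = \frac{1}{\lvert C_j \rvert} \sum_{x_i \in C_j} x_i
```

The usual algorithm alternates two steps: assign each point to its nearest centroid, then recompute each $\mu_j$ as the mean of its assigned points, until the centroids stop moving.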

For a conceptual overview of K-means clustering, refer to: Everything you need to know about K-Means Clustering

We shall now begin with the code walkthrough for implementing the K-means clustering algorithm from scratch:

Fret not! I promise you that it’s going to turn out as fascinating as it sounds!

import numpy as np
import matplotlib.pyplot as plt

np.random.seed(42)

def euclidean_distance(x1, x2):
    return np.sqrt(np.sum((x1 - x2)**2))

class KMeans():

    def __init__(self, K=5, max_iters=100, plot_steps=False):
        self.K = K
        self.max_iters = max_iters
        self.plot_steps = plot_steps

        # list of sample indices for each cluster
        self.clusters = [[] for _ in range(self.K)]
        # the centers (mean feature vector) for each cluster
        self.centroids = []

    def predict(self, X):
        # work in float64: subtracting uint8 pixel values would wrap around (e.g. 10 - 20 -> 246)
        self.X = np.asarray(X, dtype=np.float64)
        self.n_samples, self.n_features = self.X.shape

        # initialize the centroids with K distinct random samples
        random_sample_idxs = np.random.choice(self.n_samples, self.K, replace=False)
        self.centroids = [self.X[idx] for idx in random_sample_idxs]

        # Optimize clusters
        for _ in range(self.max_iters):
            # Assign samples to closest centroids (create clusters)
            self.clusters = self._create_clusters(self.centroids)
            if self.plot_steps:
                self.plot()

            # Calculate new centroids from the clusters
            centroids_old = self.centroids
            self.centroids = self._get_centroids(self.clusters)

            # stop once the centroids no longer move
            if self._is_converged(centroids_old, self.centroids):
                break

            if self.plot_steps:
                self.plot()

        # Classify samples as the index of their clusters
        return self._get_cluster_labels(self.clusters)

    def _get_cluster_labels(self, clusters):
        # each sample will get the label of the cluster it was assigned to
        labels = np.empty(self.n_samples)

        for cluster_idx, cluster in enumerate(clusters):
            for sample_index in cluster:
                labels[sample_index] = cluster_idx
        return labels

    def _create_clusters(self, centroids):
        # Assign the samples to the closest centroids to create clusters
        clusters = [[] for _ in range(self.K)]
        for idx, sample in enumerate(self.X):
            centroid_idx = self._closest_centroid(sample, centroids)
            clusters[centroid_idx].append(idx)
        return clusters

    def _closest_centroid(self, sample, centroids):
        # distance of the current sample to each centroid
        distances = [euclidean_distance(sample, point) for point in centroids]
        closest_index = np.argmin(distances)
        return closest_index

    def _get_centroids(self, clusters):
        # assign the mean value of each cluster to its centroid
        centroids = np.zeros((self.K, self.n_features))
        for cluster_idx, cluster in enumerate(clusters):
            if cluster:
                centroids[cluster_idx] = np.mean(self.X[cluster], axis=0)
            else:
                # empty cluster: keep its previous centroid instead of the NaN mean of nothing
                centroids[cluster_idx] = self.centroids[cluster_idx]
        return centroids

    def _is_converged(self, centroids_old, centroids):
        # distances between each old and new centroid, for all centroids
        distances = [euclidean_distance(centroids_old[i], centroids[i]) for i in range(self.K)]
        return sum(distances) == 0

    def plot(self):
        # intended for 2-D data: scatter(*point) uses the first two features as x and y
        fig, ax = plt.subplots(figsize=(12, 8))

        for i, index in enumerate(self.clusters):
            point = self.X[index].T
            ax.scatter(*point)

        for point in self.centroids:
            ax.scatter(*point, marker="x", color='black', linewidth=2)

        plt.show()

    def cent(self):
        # current centroids: for image data, one (R, G, B) colour per cluster
        return self.centroids
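The class above finds each pixel's nearest centroid with Python loops, calling euclidean_distance once per pixel per centroid; for the 1080×1920 image used below that is 2,073,600 pixels × 5 centroids, over ten million calls per iteration. As a hedged alternative, here is a minimal vectorized sketch (the function kmeans_vectorized is hypothetical, not the author's code) that computes all pixel-to-centroid squared distances at once with NumPy, using the identity ‖x − c‖² = ‖x‖² − 2x·c + ‖c‖² so the temporary array is only (n_samples, K).

```python
import numpy as np

def kmeans_vectorized(X, K=5, max_iters=100, seed=42):
    """Minimal vectorized K-means sketch. Returns (labels, centroids)."""
    X = np.asarray(X, dtype=np.float64)           # float math avoids uint8 wrap-around
    rng = np.random.default_rng(seed)
    # initialise with K distinct random samples, as the class above does
    centroids = X[rng.choice(len(X), size=K, replace=False)]
    labels = np.zeros(len(X), dtype=np.int64)

    for _ in range(max_iters):
        # squared distances from every sample to every centroid, shape (n_samples, K)
        x_sq = (X ** 2).sum(axis=1)[:, None]
        c_sq = (centroids ** 2).sum(axis=1)[None, :]
        d2 = x_sq - 2.0 * X @ centroids.T + c_sq
        labels = d2.argmin(axis=1)                 # assignment step

        # update step: mean of each cluster; keep the old centroid if a cluster is empty
        new_centroids = centroids.copy()
        for j in range(K):
            members = X[labels == j]
            if len(members):
                new_centroids[j] = members.mean(axis=0)

        if np.allclose(new_centroids, centroids):  # centroids stopped moving
            break
        centroids = new_centroids

    return labels, centroids
```

With the article's flattened image, `labels, centroids = kmeans_vectorized(img, K=5)` could stand in for `k.predict(img)` and `k.cent()`; the exact colours may differ because the random initialisation differs.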
# Extracting Dominant Colours in an Image
import cv2
from skimage import io
# from google.colab.patches import cv2_imshow  # only available in Google Colab; not needed below
url = "https://www.teahub.io/photos/full/35-355143_windows-10-wallpaper-umbrella.jpg"
img = io.imread(url)
img.shape
Out:
(1080, 1920, 3)
img_init = img.copy()
plt.figure(figsize=(6, 6))
plt.imshow(img_init)
Out:
<matplotlib.image.AxesImage at 0x7f49e6a7d6a0>
img = img.reshape((img.shape[0] * img.shape[1], img.shape[2]))
k = KMeans(K=5)  # for the 5 most dominant colours
y_pred = k.predict(img)
k.cent()
Out:
array([[ 53.16708662, 93.69632655, 175.47967713],
[133.07051658, 195.54817432, 47.66459003],
[237.62760178, 76.63096981, 21.42656026],
[248.665852 , 31.23121874, 121.13346739],
[206.22142881, 229.36967717, 152.23724866]])
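Each row of the centroid array is a mean (R, G, B) colour with fractional values. A small sketch (the helper centroids_to_hex is hypothetical) rounds them to valid 0–255 integers and formats them as hex colour codes, which is handy for reusing the palette elsewhere:

```python
import numpy as np

def centroids_to_hex(centroids):
    """Round float RGB centroids to valid 0-255 integers and format them as hex codes."""
    rgb = np.clip(np.rint(np.asarray(centroids, dtype=float)), 0, 255).astype(int)
    return ["#{:02x}{:02x}{:02x}".format(r, g, b) for r, g, b in rgb.tolist()]

# The first two centroids from the output above
print(centroids_to_hex([[53.16708662, 93.69632655, 175.47967713],
                        [133.07051658, 195.54817432, 47.66459003]]))
# ['#355eaf', '#85c430']
```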
y_pred
Out:
array([2., 2., 2., ..., 0., 0., 0.])
label_indx = np.arange(0, k.K + 1)  # bin edges 0..K: one bin per cluster, even if a cluster ends up empty
label_indx
Out:
array([0, 1, 2, 3, 4, 5])
np.histogram(y_pred, bins = label_indx)
Out:
(array([565545, 172073, 559377, 593291, 183314]), array([0, 1, 2, 3, 4, 5]))
(hist, _) = np.histogram(y_pred, bins = label_indx)
hist = hist.astype("float")
hist /= hist.sum()
hist
Out:
array([0.27273582, 0.08298274, 0.26976128, 0.28611642, 0.08840374])
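Because the labels are whole numbers 0 to K−1, the same proportions can also be computed by counting labels directly. This sketch (the helper cluster_shares and its sample labels are hypothetical) uses `np.bincount` with `minlength=K`, which keeps a zero entry for any empty cluster:

```python
import numpy as np

def cluster_shares(labels, K):
    """Fraction of samples assigned to each of the K clusters."""
    # bincount needs integer input; the article's labels are floats such as 2.0
    counts = np.bincount(np.asarray(labels).astype(np.int64), minlength=K)
    return counts / counts.sum()

# Hypothetical labels for 8 pixels with K=5 (cluster 4 is empty)
labels = np.array([2., 2., 0., 0., 0., 3., 1., 2.])
print(cluster_shares(labels, K=5).tolist())   # [0.375, 0.125, 0.375, 0.125, 0.0]
```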
hist_bar = np.zeros((50, 300, 3), dtype = "uint8")
startX = 0
for (percent, color) in zip(hist, k.cent()):
    endX = startX + (percent * 300)  # block width proportional to the cluster's share of the 300-pixel bar
    cv2.rectangle(hist_bar, (int(startX), 0), (int(endX), 50),
                  color.astype("uint8").tolist(), -1)
    startX = endX
plt.figure(figsize=(15,15))
plt.subplot(121)
plt.imshow(img_init)
plt.subplot(122)
plt.imshow(hist_bar)
plt.show()
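The loop above draws the colour blocks in cluster-index order, which is arbitrary. As an optional variant, this NumPy-only sketch (the helper palette_bar is hypothetical, not part of the article) builds the same kind of bar with array slicing instead of `cv2.rectangle`, ordering the blocks from most to least dominant and rounding the boundaries so the bar is filled exactly to its width:

```python
import numpy as np

def palette_bar(proportions, colours, height=50, width=300):
    """Return a (height, width, 3) uint8 bar with one block per colour, most dominant first."""
    proportions = np.asarray(proportions, dtype=float)
    colours = np.asarray(colours, dtype=float)
    order = np.argsort(proportions)[::-1]                  # largest share first
    # cumulative block boundaries in pixels; the last boundary lands on `width`
    edges = np.rint(np.cumsum(proportions[order]) / proportions.sum() * width).astype(int)
    starts = np.concatenate(([0], edges[:-1]))
    bar = np.zeros((height, width, 3), dtype=np.uint8)
    for start, end, idx in zip(starts, edges, order):
        # fill the block with the centroid colour rounded to valid 0-255 values
        bar[:, start:end] = np.clip(np.rint(colours[idx]), 0, 255).astype(np.uint8)
    return bar
```

With the article's variables, `plt.imshow(palette_bar(hist, k.cent()))` would show the same colours as `hist_bar`, reordered by their share of the image.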

Hope you enjoyed and made the most out of this article! Stay tuned for my upcoming blogs! Make sure to CLAP and FOLLOW if you find my content helpful/informative!

For complete code implementation:
