Mean Shift Clustering in Machine Learning with scikit-learn


In machine learning, Mean Shift clustering is about finding the dominant concentrations, or “modes”, in the data through cluster analysis. When you apply clustering to data, you look for groups of data points that share similar characteristics. By identifying these dominant modes, you try to understand which groups or regions of high density best characterize the data. This can be useful for understanding changes in data behavior over time, spotting anomalies, or identifying significant trends.

The Mean Shift Algorithm

Compared with algorithms such as K-Means, Mean Shift became popular as a clustering method relatively recently: the underlying mean shift procedure was introduced by Fukunaga and Hostetler in 1975, and it was later popularized as a clustering and feature-space analysis technique by Dorin Comaniciu and Peter Meer in their 2002 paper titled “Mean Shift: A Robust Approach Toward Feature Space Analysis”.

The idea behind Mean Shift goes back to the concept of “modes” or “peaks” in a data set, which refers to the points where the data density is maximum. The goal of the Mean Shift algorithm is to identify these modes or peaks in the data and group together the points that belong to each mode.

The name “Mean Shift” comes from the fact that the algorithm works by iteratively shifting points in the dataset towards the direction of maximum density growth, i.e. towards the closest “mode”. This process of “shifting” the points towards the maximum density continues until the points converge towards the cluster modes.

In the years since its introduction, Mean Shift has become popular for its ability to identify clusters of arbitrary shape and variable size without the need to specify the number of clusters a priori. It has been successfully used in several fields, including computer vision, data analysis, and pattern recognition.

The algorithm works through the following steps:

1. Defining the Kernel: We start by defining a kernel function, usually a Gaussian, that weights the contribution of each point when estimating the data density. An example of a Gaussian kernel is given by:

 K(x) = \frac{1}{{\sqrt{2\pi}\sigma}} e^{-\frac{x^2}{2\sigma^2}}

Where  \sigma is the standard deviation parameter.
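
As a minimal sketch of this formula (assuming a NumPy environment; the gaussian_kernel name is just an illustrative helper, not a scikit-learn function), the kernel could be implemented as:

import numpy as np

def gaussian_kernel(x, sigma=1.0):
    # K(x) = 1 / (sqrt(2*pi) * sigma) * exp(-x^2 / (2 * sigma^2))
    return np.exp(-x**2 / (2 * sigma**2)) / (np.sqrt(2 * np.pi) * sigma)

# Example: weight assigned to a point at distance 0.5
print(gaussian_kernel(0.5))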

2. Calculation of the Density Estimate: Using the kernel, we can calculate the density estimate of the data at each point. This can be expressed as:

 f(x) = \frac{1}{nh^d} \sum_{i=1}^{n} K\left(\frac{x-x_i}{h}\right)

Where:

 n is the number of data points.

 h is the bandwidth window parameter.

 d is the dimensionality of the feature space.

 x_i are the data points.
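
A possible NumPy sketch of this density estimate follows (the density_estimate function is hypothetical and, as is common in practice, the kernel is applied to the scaled Euclidean distance ||x - x_i|| / h):

import numpy as np

def density_estimate(x, data, h=1.0, sigma=1.0):
    # f(x) = 1 / (n * h^d) * sum_i K(||x - x_i|| / h)
    n, d = data.shape
    u = np.linalg.norm(data - x, axis=1) / h  # scaled distances to every data point
    kernel = np.exp(-u**2 / (2 * sigma**2)) / (np.sqrt(2 * np.pi) * sigma)  # Gaussian kernel values
    return kernel.sum() / (n * h**d)

# Example: density estimate at the origin for a tiny 2D dataset
data = np.array([[0.0, 0.0], [0.5, 0.2], [2.0, 2.0]])
print(density_estimate(np.array([0.0, 0.0]), data, h=1.0))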

3. Calculating the Mean Shift Vector: For each point in the dataset, we calculate a vector that points towards the area of maximum density, i.e. in the direction of the gradient of the density estimate (gradient ascent). For a Gaussian kernel, this vector is the difference between the kernel-weighted mean of the surrounding points and the current point:

 m(x) = \frac{\sum_{i=1}^{n} K\left(\frac{x-x_i}{h}\right) x_i}{\sum_{i=1}^{n} K\left(\frac{x-x_i}{h}\right)} - x

Where the weights  K\left(\frac{x-x_i}{h}\right)  are the kernel values computed in step 2.

4. Point Update: We move each point in the dataset in the direction of the calculated mean shift vector:

 x_{t+1} = x_t + m(x_t)

5. Convergence: We repeat steps 3 and 4 until the points converge towards the local maxima of the density.
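
Putting steps 3 to 5 together, a simplified sketch of the procedure for a single point might look like the following (mean_shift_point is a hypothetical helper, not the scikit-learn implementation; it uses Gaussian weights as in the previous snippets):

import numpy as np

def mean_shift_point(x, data, h=1.0, max_iter=100, tol=1e-5):
    # Iteratively shift a single point x towards the nearest density mode
    for _ in range(max_iter):
        u = np.linalg.norm(data - x, axis=1) / h  # scaled distances to all data points
        weights = np.exp(-u**2 / 2)               # Gaussian kernel weights
        weighted_mean = (weights[:, None] * data).sum(axis=0) / weights.sum()
        shift = weighted_mean - x                 # step 3: mean shift vector m(x)
        x = x + shift                             # step 4: point update
        if np.linalg.norm(shift) < tol:           # step 5: stop when converged
            break
    return x

# Example: a point starting at (1, 1) is pulled towards the dense group near the origin
data = np.array([[0.0, 0.0], [0.1, 0.1], [0.2, 0.0], [3.0, 3.0]])
print(mean_shift_point(np.array([1.0, 1.0]), data, h=1.0))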

Once the Mean Shift clustering is completed, we identify the dominant modes by observing the high-density regions obtained. These regions represent the clusters, i.e. the groups of data with the greatest concentration of points.

In summary, Mean Shift clustering finds the regions of maximum density in the data through cluster analysis, identifying the points of greatest concentration.

Mean shift clustering with scikit-learn

As a first step, we generate a sample dataset using scikit-learn’s make_blobs() function. This function creates a specified number of random clusters, each centered at a random location in the feature space.

from sklearn.cluster import MeanShift
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt

X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=0)

# Visualize the generated data
plt.scatter(X[:, 0], X[:, 1], s=50)
plt.title('Generated Data')
plt.xlabel("Feature 1")
plt.ylabel("Feature 2")
plt.show()

Running the code, you get the following result:

Mean Shift Clustering - dataset

Now that we have some data, we can apply Mean Shift to identify the clusters present:

from sklearn.cluster import MeanShift, estimate_bandwidth
from sklearn.datasets import make_blobs
import matplotlib.pyplot as plt
import numpy as np

# Generate sample data
X, _ = make_blobs(n_samples=300, centers=4, cluster_std=0.6, random_state=0)

# Estimate bandwidth
bandwidth = estimate_bandwidth(X, quantile=0.2, n_samples=len(X))

# Create MeanShift object
ms = MeanShift(bandwidth=bandwidth)

# Fit the model
ms.fit(X)

# Cluster labels
labels = ms.labels_

# Number of clusters obtained
n_clusters = len(np.unique(labels))

print("Number of clusters obtained:", n_clusters)

# Plot the results
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis')
plt.title('Mean Shift Clustering')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()

First, we use scikit-learn’s estimate_bandwidth() function to estimate a suitable bandwidth from the data. The quantile parameter controls how narrow or wide the bandwidth window will be relative to the data distribution. We then create an instance of scikit-learn’s MeanShift class, passing the estimated value as the bandwidth parameter, and train the Mean Shift model on the X data.

Running the code you get the expected result:

Mean Shift Clustering - clusters chart
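
If you also want to inspect the modes found by the algorithm, the fitted MeanShift object exposes their coordinates through its cluster_centers_ attribute. For example, you can print them and overlay them on the scatter plot:

# Coordinates of the density modes (cluster centers) found by Mean Shift
centers = ms.cluster_centers_
print("Cluster centers:\n", centers)

# Overlay the centers on the clustered data
plt.scatter(X[:, 0], X[:, 1], c=labels, cmap='viridis', s=50)
plt.scatter(centers[:, 0], centers[:, 1], c='red', marker='x', s=200)
plt.title('Mean Shift Clustering with Cluster Centers')
plt.xlabel('Feature 1')
plt.ylabel('Feature 2')
plt.show()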

The bandwidth parameter

The “bandwidth” parameter is a key parameter in the Mean Shift algorithm that determines the scale at which data density is evaluated during the clustering process. In other words, it controls the distance at which surrounding points are considered when calculating the data density at a given point.

The need for the “bandwidth” parameter arises from the fact that Mean Shift is based on finding local maxima of the estimated data density. Too small a bandwidth can split the data too finely into numerous small clusters, while too large a bandwidth can cause a loss of detail and merge distinct groups together.

In essence, the width of the bandwidth window directly affects the sensitivity of the Mean Shift model in recognizing clusters in the data. An appropriate bandwidth value is therefore crucial to obtain correct cluster recognition. If the bandwidth window width is set correctly, the algorithm will be able to identify clusters based on the local density of the data, correctly capturing variations in the data and grouping together points that share a similar density.

In practice, it is often advisable to estimate the bandwidth window width using techniques such as automatic estimation via scikit-learn’s estimate_bandwidth() function, as in the case of the example I showed you earlier. This allows you to adapt the bandwidth window width based on the specific data distribution and can significantly improve the performance of Mean Shift clustering.
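
As a quick experiment (a sketch that reuses the X dataset generated above), you can observe how different values of the quantile parameter, and therefore different estimated bandwidths, change the number of clusters found:

from sklearn.cluster import MeanShift, estimate_bandwidth
import numpy as np

# Compare several quantile values and the resulting number of clusters
for quantile in [0.05, 0.1, 0.2, 0.3, 0.5]:
    bw = estimate_bandwidth(X, quantile=quantile, n_samples=len(X))
    ms_test = MeanShift(bandwidth=bw).fit(X)
    n = len(np.unique(ms_test.labels_))
    print(f"quantile={quantile} -> bandwidth={bw:.3f}, clusters={n}")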

Evaluation of the results obtained

It is possible to calculate an internal validity measure such as the Silhouette index to evaluate the clustering results obtained from the example code above. The Silhouette index measures cohesion within clusters and separation between clusters, and it can be calculated with the scikit-learn library.

Here’s how you might do it in the context of our example code:

from sklearn.metrics import silhouette_score

silhouette_avg = silhouette_score(X, labels)
print("Silhouette Score:", silhouette_avg)

In this code, X represents our dataset and labels are the cluster labels assigned by the Mean Shift algorithm. We use scikit-learn’s silhouette_score() function to calculate the Silhouette index. Running it, we get:

Silhouette Score: 0.6819938690643478

The Silhouette Score value obtained, which is approximately 0.682, is generally considered a good result. However, evaluating the validity of a Silhouette Score depends on the specific context of the data and may vary based on the nature of the data and the clustering problem.

Here is a general guide to interpreting the Silhouette Score:

  • A Silhouette Score close to 1 indicates that the clusters have been well separated and that the data points have been consistently assigned to their clusters.
  • A Silhouette Score near 0 indicates that clusters may have significant overlap or that data points may be close to the boundaries between clusters.
  • A Silhouette Score close to -1 indicates that the data points may have been assigned to the wrong cluster.

So, in our case, a Silhouette Score of around 0.682 suggests that the clusters found by the Mean Shift algorithm have good separation and internal consistency. However, it is important to note that the Silhouette Score alone may not be sufficient to fully determine the validity of the clusters. It is always a good idea to also consider other evaluation metrics and visually analyze the clustering results to get a complete understanding of the quality of the clustering achieved.
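
For instance, to complement the single average score, you could inspect the per-point silhouette values with scikit-learn’s silhouette_samples() function (a small sketch that reuses the same X and labels):

from sklearn.metrics import silhouette_samples
import numpy as np

# One silhouette coefficient per data point
sample_silhouettes = silhouette_samples(X, labels)

# Average silhouette per cluster, to spot weakly separated groups
for label in np.unique(labels):
    cluster_avg = sample_silhouettes[labels == label].mean()
    print(f"Cluster {label}: mean silhouette = {cluster_avg:.3f}")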

There are two other indices that can be used to further evaluate the quality of the clustering results: the Davies-Bouldin coefficient and the Dunn index.

Here’s how to calculate them: the Davies-Bouldin score is available directly in scikit-learn, while the Dunn index is not, so we define a small helper that computes it from pairwise distances:

from sklearn.metrics import davies_bouldin_score, pairwise_distances
import numpy as np

# Calculation of the Davies-Bouldin coefficient
davies_bouldin = davies_bouldin_score(X, labels)
print("Davies-Bouldin Score:", davies_bouldin)

# Calculation of the Dunn index
def dunn_index(X, labels):
    clusters = np.unique(labels)
    # Largest distance between two points of the same cluster (cluster diameter)
    max_diameter = max([np.max(pairwise_distances(X[labels == label])) for label in clusters])
    # Smallest distance between a point of a cluster and a point outside it
    min_intercluster_distance = min([np.min(pairwise_distances(X[labels == label], X[labels != label])) for label in clusters])
    return min_intercluster_distance / max_diameter

dunn = dunn_index(X, labels)
print("Dunn Index:", dunn)

In this code, we calculate the Davies-Bouldin coefficient using scikit-learn’s davies_bouldin_score() function and the Dunn index using the dunn_index() function we have defined. Both values give us further information on the validity of the clustering results: a lower value for the Davies-Bouldin coefficient and a higher value for the Dunn index indicate better clustering quality. Running the code, you get:

Davies-Bouldin Score: 0.4375640078237839
Dunn Index: 0.20231427477727162

The Davies-Bouldin coefficient measures the separation between clusters and the cohesion within each cluster. A lower value indicates better separation between clusters and greater cohesion within each cluster. In our case, the value of around 0.438 suggests that the clusters have good separation and cohesion.

The Dunn index is a measure of separation between clusters versus dispersion within clusters. A higher value indicates greater separation between clusters and less dispersion within clusters. In our case, the value of around 0.202 suggests that the separation between clusters is relatively low compared to the dispersion within clusters. However, it is important to note that the interpretation of the Dunn index can be influenced by the scale of the data and the nature of the clusters.

When to use Mean Shift?

Mean Shift is a powerful and flexible clustering algorithm, but it is not necessarily the best choice for every data type or clustering problem. However, there are some situations where Mean Shift can be particularly effective compared to other clustering algorithms. Here are some of them:

  • Non-convex-shaped clusters: Mean Shift is effective at finding non-convex-shaped or irregularly shaped clusters. Because it makes no assumptions about the shape of the clusters, it can handle complex and more flexible shapes than assumption-based methods like K-Means.
  • Clusters of different sizes: Mean Shift does not require specifying the number of clusters a priori and can automatically find clusters of different sizes. It is especially useful when clusters have significant variations in density and size.
  • Variable density clusters: Mean Shift is sensitive to data density and can effectively handle variable density clusters. With an appropriately chosen bandwidth window it can reflect the changing data density, allowing it to correctly identify clusters even in regions of different density.
  • Few hyperparameters to tune: Mean Shift has only one main hyperparameter to tune, which is the bandwidth window width. This makes it relatively simple to use and less sensitive to the choice of parameters than other clustering algorithms that require specifying the number of clusters.

However, Mean Shift may not be the best choice in all situations. For example:

  • Large datasets: Mean Shift can be computationally expensive on very large datasets due to its computational complexity. In this case, more efficient algorithms such as K-Means or DBSCAN might be preferable.
  • Clusters of Defined Global Shape: If clusters are of defined global shape and are clearly separable, distance-based algorithms such as K-Means or hierarchical algorithms can produce more efficient and interpretable results.
  • Data with significant noise: If the data contains a high level of noise or anomalous points, Mean Shift may have difficulty handling them effectively. In this case, DBSCAN or other density-based clustering algorithms may be more suitable.
