Skip to content

Commit

Permalink
KMedoids supports array-like init method (#137)
Browse files Browse the repository at this point in the history
* KMedoids supports array-like init method

* make input X optional

* remove unnecessary data type checking

* fix formatting according to flake8 suggestions

* fix format

* add KMedoids array-like init test case for max_iter=0

* add test cases for n_clusters > number of centroids when using array-like init method

* lint formatting
  • Loading branch information
makcedward authored Jan 5, 2022
1 parent b04ec47 commit cb70ef0
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
35 changes: 29 additions & 6 deletions sklearn_extra/cluster/_k_medoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,8 @@ class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
method : {'alternate', 'pam'}, default: 'alternate'
Which algorithm to use. 'alternate' is faster while 'pam' is more accurate.
init : {'random', 'heuristic', 'k-medoids++', 'build'}, optional, default: 'heuristic'
init : {'random', 'heuristic', 'k-medoids++', 'build'}, or array-like of shape
(n_clusters, n_features), optional, default: 'heuristic'
Specify medoid initialization method. 'random' selects n_clusters
elements from the dataset. 'heuristic' picks the n_clusters points
with the smallest sum distance to every other point. 'k-medoids++'
Expand All @@ -74,6 +75,8 @@ class KMedoids(BaseEstimator, ClusterMixin, TransformerMixin):
algorithm. Often 'build' is more efficient but slower than other
initializations on big datasets and it is also very non-robust,
if there are outliers in the dataset, use another initialization.
If an array is passed, it should be of shape (n_clusters, n_features)
and gives the initial centers.
.. _k-means++: https://theory.stanford.edu/~sergei/papers/kMeansPP-soda.pdf
Expand Down Expand Up @@ -181,13 +184,29 @@ def _check_init_args(self):

# Check init
init_methods = ["random", "heuristic", "k-medoids++", "build"]
if self.init not in init_methods:
if not (
hasattr(self.init, "__array__")
or (isinstance(self.init, str) and self.init in init_methods)
):
raise ValueError(
"init needs to be one of "
+ "the following: "
+ "%s" % init_methods
+ "%s" % (init_methods + ["array-like"])
)

# Check n_clusters
if (
hasattr(self.init, "__array__")
and self.n_clusters != self.init.shape[0]
):
warnings.warn(
"n_clusters should be equal to size of array-like if init "
"is array-like setting n_clusters to {}.".format(
self.init.shape[0]
)
)
self.n_clusters = self.init.shape[0]

def fit(self, X, y=None):
"""Fit K-Medoids to the provided data.
Expand Down Expand Up @@ -219,7 +238,7 @@ def fit(self, X, y=None):
D = pairwise_distances(X, metric=self.metric)

medoid_idxs = self._initialize_medoids(
D, self.n_clusters, random_state_
D, self.n_clusters, random_state_, X
)
labels = None

Expand Down Expand Up @@ -407,10 +426,14 @@ def predict(self, X):

return pd_argmin

def _initialize_medoids(self, D, n_clusters, random_state_):
def _initialize_medoids(self, D, n_clusters, random_state_, X=None):
"""Select initial mediods when beginning clustering."""

if self.init == "random": # Random initialization
if hasattr(self.init, "__array__"): # Pre assign cluster
medoids = np.hstack(
[np.where((X == c).all(axis=1)) for c in self.init]
).ravel()
elif self.init == "random": # Random initialization
# Pick random k medoids as the initial ones.
medoids = random_state_.choice(len(D), n_clusters, replace=False)
elif self.init == "k-medoids++":
Expand Down
26 changes: 26 additions & 0 deletions sklearn_extra/cluster/tests/test_k_medoids.py
Original file line number Diff line number Diff line change
Expand Up @@ -423,7 +423,33 @@ def test_medoids_indices():

model = KMedoids(n_clusters=3, init="build", random_state=rng)

centroids = np.array([X_iris[0], X_iris[50]])
array_like_model = KMedoids(
n_clusters=len(centroids), init=centroids, max_iter=0
)

model.fit(X_iris)
clara.fit(X_iris)
array_like_model.fit(X_iris)
assert_array_equal(X_iris[model.medoid_indices_], model.cluster_centers_)
assert_array_equal(X_iris[clara.medoid_indices_], clara.cluster_centers_)
assert_array_equal(centroids, array_like_model.cluster_centers_)


def test_array_like_init():
    """Check that KMedoids accepts an array of initial centers as ``init``."""
    initial_centers = np.array([X_cc[0], X_cc[50]])
    n_centers = len(initial_centers)
    target_labels = np.hstack([np.zeros(50), np.ones(50)])

    model = KMedoids(n_clusters=n_centers, init=initial_centers)
    model.fit(X_cc)

    # The data are not perfectly separable, so the accuracy is around 0.85
    # rather than 1. Accept either the agreement rate or its complement,
    # since the two cluster ids may come out swapped relative to the target.
    agreement = np.mean(model.labels_ == target_labels)
    assert agreement > 0.8 or 1 - agreement > 0.8

    # An n_clusters that disagrees with the array-like init is overridden
    # by the number of rows in the init array.
    model = KMedoids(n_clusters=n_centers + 2, init=initial_centers)
    model.fit(X_cc)

    assert len(model.cluster_centers_) == n_centers

0 comments on commit cb70ef0

Please sign in to comment.