Skip to content

Commit

Permalink
more
Browse files Browse the repository at this point in the history
  • Loading branch information
gagolews committed Feb 4, 2025
1 parent bbbf949 commit 5963845
Showing 1 changed file with 72 additions and 24 deletions.
96 changes: 72 additions & 24 deletions .devel/mst_anal.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,22 +3,43 @@
import pandas as pd
import matplotlib.pyplot as plt
import clustbench

import os.path


# from sklearn.datasets import make_blobs
# X, labels = make_blobs(
# n_samples=[1000, 100, 100],
# cluster_std=1,
# random_state=42
# )
# skiplist = [1197, 1198]


examples = [
["sipu", "aggregation", [785, 784, 786]],
["sipu", "pathbased", [293, 271, 294]],
]

example = examples[0]
data_path = os.path.join("~", "Projects", "clustering-data-v1")
b = clustbench.load_dataset("sipu", "pathbased", path=data_path)
X = b.data
#labels = b.labels[0]
np.random.seed(123)
b = clustbench.load_dataset(example[0], example[1], path=data_path)
X, labels = b.data, b.labels[0]
skiplist = example[2]



n = X.shape[0]
min_cluster_size = n/20

mst = genieclust.internal.mst_from_distance(X, "euclidean")
mst_w, mst_e = mst
n = len(mst_w)+1
adj_list = [ [] for i in range(n) ]
for i in range(n-1):
adj_list[mst_e[i, 0]].append(i)
adj_list[mst_e[i, 1]].append(i)

for i in range(n):
adj_list[i] = np.array(adj_list[i])

def visit(v, e, c): # v->w where mst_e[e,:]={v,w}
if mst_labels[e] < 0: # skiplist
Expand All @@ -37,11 +58,8 @@ def visit(v, e, c): # v->w where mst_e[e,:]={v,w}
return tot



min_cluster_size = n/20
# TODO: store mst_s in each iteration
# TODO: restart from the vertices adjacent to the skipped edge
skiplist = [293, 271]
mst_s = np.zeros((n-1, 2), dtype=int)
labels = np.zeros(n, dtype=int)
mst_labels = np.zeros(n-1, dtype=int)
Expand All @@ -66,30 +84,42 @@ def visit(v, e, c): # v->w where mst_e[e,:]={v,w}
#(mst_w[op])[min_mst_s[op]>min_cluster_size]
which_cut = op[np.nonzero(min_mst_s[op]>min_cluster_size)]
print(which_cut)


#
plt.clf()
#
ax1 = plt.subplot(2, 2, 1)
ax1 = plt.subplot(3, 2, 1)
ax1.plot(mst_w[op])
ax2 = ax1.twinx()
ax2.plot(np.arange(n-1), min_mst_s[op], c='orange', alpha=0.3)
for i in which_cut:
plt.text(op[i], min_mst_s[i], i, ha='center')
#
plt.subplot(2, 2, (2, 4))
genieclust.plots.plot_scatter(X, labels=labels)
genieclust.plots.plot_segments(mst_e, X, style="b-", alpha=0.1)
genieclust.plots.plot_segments(mst_e[min_mst_s>min_cluster_size,:], X, style="r-")
genieclust.plots.plot_segments(mst_e[mst_labels<0,:], X, style="w-")
plt.axis("equal")
for i in range(n-1):
plt.text(
(X[mst_e[i,0],0]+X[mst_e[i,1],0])/2,
(X[mst_e[i,0],1]+X[mst_e[i,1],1])/2,
"%d (%d)" % (i, min(mst_s[i, 0], mst_s[i, 1]))
)
# treelhouette
plt.subplot(3, 2, 5)
cluster_distances = np.ones((c, c))*np.inf
for e in skiplist:
i, j = mst_e[e, :]
cluster_distances[labels[i]-1, labels[j]-1] = mst_w[e]
cluster_distances[labels[j]-1, labels[i]-1] = mst_w[e]
# leave the diagonal to inf
min_intercluster_distances = np.min(cluster_distances, axis=0)
#
a = np.zeros(n)
for i in range(n):
a[i] = np.min(mst_w[adj_list[i][mst_labels[adj_list[i]] == labels[i]]])
b = min_intercluster_distances[labels-1]
s = np.where(a<b, 1.0 - a/b, b/a - 1.0)
#
o1 = np.argsort(s)[::-1]
o2 = np.argsort(labels[o1], kind='stable')
#plt.plot(s[o1][o2])
#plt.ylim(0, 1)
plt.bar(np.arange(n), s[o1][o2], width=1.0, color=np.array(genieclust.plots.col)[labels[o1]][o2])
#
ax1 = plt.subplot(2, 2, 3)
#
ax1 = plt.subplot(3, 2, 3)
ax2 = ax1.twinx()
last = 0
for i in range(1, c+1):
Expand All @@ -107,7 +137,25 @@ def visit(v, e, c): # v->w where mst_e[e,:]={v,w}
alpha=0.2
)
last += counts[i] - 1
#
# MST
plt.subplot(3, 2, (2, 6))
genieclust.plots.plot_scatter(X, labels=labels)
genieclust.plots.plot_segments(mst_e, X, style="b-", alpha=0.1)
genieclust.plots.plot_segments(mst_e[min_mst_s>min_cluster_size,:], X, alpha=0.5)
genieclust.plots.plot_segments(mst_e[mst_labels<0,:], X, style="w-")
plt.axis("equal")
for i in range(n-1):
plt.text(
(X[mst_e[i,0],0]+X[mst_e[i,1],0])/2,
(X[mst_e[i,0],1]+X[mst_e[i,1],1])/2,
"%d (%d)" % (i, min(mst_s[i, 0], mst_s[i, 1])),
color=genieclust.plots.col[mst_labels[i]]
)



# DBSCAN is non-adaptive - cannot detect clusters of different densities well
plt.violinplot([ mst_w[mst_labels==i] for i in range(1, c+1) ])
#plt.violinplot([ mst_w[mst_labels==i] for i in range(1, c+1) ])


0 comments on commit 5963845

Please sign in to comment.