-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathrun_clustering_chunked.py
71 lines (63 loc) · 2.18 KB
/
run_clustering_chunked.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import argparse
import os
from pathlib import Path
import numpy as np
from joblib import dump
from sklearn.cluster import KMeans, MiniBatchKMeans
def load_chunk_features(chunk_path):
    """Return the feature array stored under ``"arr_0"`` in an ``.npz`` chunk.

    The archive is opened in a ``with`` block so that it is closed
    deterministically once the array has been read, rather than relying on
    garbage collection while many chunks are visited over many epochs.

    Args:
        chunk_path: Path (or binary file-like object) of an ``.npz`` archive
            that contains an array under the key ``"arr_0"``.

    Returns:
        The array saved under ``"arr_0"``.

    Raises:
        KeyError: If the archive has no ``"arr_0"`` entry.
    """
    with np.load(chunk_path) as archive:
        return archive["arr_0"]


def parse_args(argv=None):
    """Parse the command-line options of the chunked clustering script.

    Args:
        argv: Optional list of argument strings; ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser()
    # The three path options are mandatory: without them the script would
    # otherwise only fail with a TypeError, for the output paths only after
    # the whole training loop has already run.
    parser.add_argument(
        "--embedded-chunk-paths",
        type=str,
        nargs="+",
        required=True,
        help="Paths to files containing chunks of the dataset with sentence representations",
    )
    parser.add_argument(
        "--out-file-model",
        type=str,
        required=True,
        help="Path to file to save kmeans model",
    )
    parser.add_argument(
        "--out-file-labels",
        type=str,
        required=True,
        help="Path to file to save cluster labels for the whole dataset",
    )
    parser.add_argument(
        "--n-clusters", type=int, default=8, help="Number of clusters of k-means"
    )
    parser.add_argument(
        "--num-epochs",
        type=int,
        default=100,
        help="For how many epochs to run the k-means algorithm?",
    )
    parser.add_argument("--verbose", type=int, default=0, help="Verbosity")
    parser.add_argument(
        "--random-state", type=int, default=42, help="Random state"
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Fit a mini-batch k-means model chunk by chunk, then label every chunk.

    The model is updated with ``partial_fit`` on each chunk for
    ``--num-epochs`` passes, saved with joblib, and finally used to predict a
    cluster label for every row of every chunk; the labels are written, one
    per line, in the order the chunk paths were given.

    Args:
        argv: Optional list of argument strings forwarded to ``parse_args``.
    """
    args = parse_args(argv)
    model = MiniBatchKMeans(
        n_clusters=args.n_clusters,
        compute_labels=True,
        verbose=args.verbose,
        random_state=args.random_state,
    )

    # fit
    for i_epoch in range(args.num_epochs):
        print(f"{i_epoch=}")
        for subdataset_path in args.embedded_chunk_paths:
            # Pass the chunk as a temporary so it can be freed before the
            # next chunk is loaded (only one chunk is referenced at a time).
            model.partial_fit(load_chunk_features(subdataset_path))
        # NOTE(review): after partial_fit this presumably reflects only the
        # most recently seen chunk, not the whole dataset - confirm.
        print(f"{model.inertia_=}")

    Path(args.out_file_model).parent.mkdir(parents=True, exist_ok=True)
    dump(model, args.out_file_model)
    print(f"Saved model to {args.out_file_model}")

    # label
    # Keep one label array per chunk and join them once, instead of unpacking
    # every label into a Python list element by element.
    label_chunks = [
        model.predict(load_chunk_features(subdataset_path))
        for subdataset_path in args.embedded_chunk_paths
    ]
    all_labels = np.concatenate(label_chunks)

    Path(args.out_file_labels).parent.mkdir(parents=True, exist_ok=True)
    np.savetxt(args.out_file_labels, all_labels, fmt="%i")
    print(f"Saved labels to {args.out_file_labels}")


if __name__ == "__main__":
    main()