Skip to content

Commit

Permalink
Merge pull request #286 from ZJUEarthData/dev/Yongkang
Browse files Browse the repository at this point in the history
feat: extract clustering public functions
  • Loading branch information
SanyHe authored Dec 12, 2023
2 parents 5457e00 + b9a1128 commit 48f0e9e
Show file tree
Hide file tree
Showing 3 changed files with 147 additions and 42 deletions.
104 changes: 62 additions & 42 deletions geochemistrypi/data_mining/model/clustering.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,20 @@
# -*- coding: utf-8 -*-
import json
import os
from typing import Dict, Optional, Union

import mlflow
import numpy as np
import pandas as pd
from rich import print
from sklearn import metrics
from sklearn.cluster import DBSCAN, AffinityPropagation, KMeans

from ..constants import MLFLOW_ARTIFACT_DATA_PATH, MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH
from ..utils.base import clear_output, save_data, save_fig
from ..utils.base import clear_output, save_data, save_fig, save_text
from ._base import WorkflowBase
from .func.algo_clustering._common import plot_results, plot_silhouette_diagram, score
from .func.algo_clustering._dbscan import dbscan_manual_hyper_parameters, dbscan_result_plot
from .func.algo_clustering._kmeans import kmeans_manual_hyper_parameters, plot_silhouette_diagram, scatter2d, scatter3d
from .func.algo_clustering._kmeans import kmeans_manual_hyper_parameters, scatter2d, scatter3d


class ClusteringWorkflowBase(WorkflowBase):
Expand Down Expand Up @@ -52,6 +53,63 @@ def get_labels(self):
GEOPI_OUTPUT_ARTIFACTS_DATA_PATH = os.getenv("GEOPI_OUTPUT_ARTIFACTS_DATA_PATH")
save_data(self.clustering_result, f"{self.naming} Result", GEOPI_OUTPUT_ARTIFACTS_DATA_PATH, MLFLOW_ARTIFACT_DATA_PATH)

@staticmethod
def _score(data: pd.DataFrame, labels: pd.DataFrame, algorithm_name: str, store_path: str) -> None:
    """Compute the clustering scores, hand them to ``save_text`` and log them to MLflow.

    Parameters
    ----------
    data : pd.DataFrame
        The samples that were clustered.
    labels : pd.DataFrame
        The cluster label assigned to each sample.
    algorithm_name : str
        Name of the algorithm; used in the title of the saved score text.
    store_path : str
        Location passed to ``save_text`` for the score text.
    """
    print("-----* Model Score *-----")
    metric_values = score(data, labels)
    # The same dict is serialised as indented JSON for the text artifact and logged as MLflow metrics.
    save_text(json.dumps(metric_values, indent=4), f"Model Score - {algorithm_name}", store_path)
    mlflow.log_metrics(metric_values)

@staticmethod
def _plot_results(data: pd.DataFrame, labels: pd.DataFrame, algorithm_name: str, cluster_centers_: pd.DataFrame, local_path: str, mlflow_path: str) -> None:
    """Plot the clustering results and save both the figure and the labelled data.

    Parameters
    ----------
    data : pd.DataFrame
        The samples that were clustered.
    labels : pd.DataFrame
        The cluster label assigned to each sample.
    algorithm_name : str
        Name of the algorithm; used in the figure and data artifact names.
    cluster_centers_ : pd.DataFrame
        Cluster centers forwarded unchanged to ``plot_results``.
    local_path : str
        Local output location passed to ``save_fig`` / ``save_data``.
    mlflow_path : str
        MLflow artifact location passed to ``save_fig`` / ``save_data``.
    """
    print("-----* results diagram *-----")
    # The figure and the data table are stored under the same artifact name.
    artifact_name = f"results - {algorithm_name}"
    plot_results(data, labels, algorithm_name, cluster_centers_)
    save_fig(artifact_name, local_path, mlflow_path)
    labelled_data = pd.concat([data, labels], axis=1)
    save_data(labelled_data, artifact_name, local_path, mlflow_path)

@staticmethod
def _plot_silhouette_diagram(
    data: pd.DataFrame,
    labels: pd.DataFrame,
    cluster_centers_: Union[pd.DataFrame, np.ndarray, str, None],
    algorithm_name: str,
    local_path: str,
    mlflow_path: str,
) -> None:
    """Plot the silhouette diagram of the clustering result and save its data.

    Parameters
    ----------
    data : pd.DataFrame
        The samples that were clustered.
    labels : pd.DataFrame
        The cluster label assigned to each sample.
    cluster_centers_ : pd.DataFrame or np.ndarray
        Cluster centers with one column per column of ``data``. Any other
        value (e.g. ``None`` or a placeholder string for algorithms without
        centers) is ignored and no center table is saved.
    algorithm_name : str
        Name of the algorithm; used in the figure title and artifact name.
    local_path : str
        Local output location passed to ``save_fig`` / ``save_data``.
    mlflow_path : str
        MLflow artifact location passed to ``save_fig`` / ``save_data``.
    """
    print("-----* Silhouette Diagram *-----")
    plot_silhouette_diagram(data, labels, algorithm_name)
    save_fig(f"Silhouette Diagram - {algorithm_name}", local_path, mlflow_path)
    data_with_labels = pd.concat([data, labels], axis=1)
    save_data(data_with_labels, "Silhouette Diagram - Data With Labels", local_path, mlflow_path)
    # Accept NumPy arrays as well as DataFrames: a DataFrame-only check would
    # silently skip saving centers that arrive as a plain ndarray.
    if isinstance(cluster_centers_, (pd.DataFrame, np.ndarray)):
        cluster_center_data = pd.DataFrame(cluster_centers_, columns=data.columns)
        save_data(cluster_center_data, "Silhouette Diagram - Cluster Centers", local_path, mlflow_path)

def common_components(self) -> None:
    """Run the functions shared by every clustering algorithm.

    Scores the model, then draws the result plot and the silhouette diagram.
    Output locations are read from the ``GEOPI_OUTPUT_METRICS_PATH`` and
    ``GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH`` environment variables.
    """
    metrics_dir = os.getenv("GEOPI_OUTPUT_METRICS_PATH")
    image_dir = os.getenv("GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH")
    self._score(
        data=self.X,
        labels=self.clustering_result["clustering result"],
        algorithm_name=self.naming,
        store_path=metrics_dir,
    )
    # Both plotting steps take the same keyword arguments; every argument is
    # re-evaluated per step, so the cluster centers are fetched once per plot.
    for render in (self._plot_results, self._plot_silhouette_diagram):
        render(
            data=self.X,
            labels=self.clustering_result["clustering result"],
            cluster_centers_=self.get_cluster_centers(),
            algorithm_name=self.naming,
            local_path=image_dir,
            mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
        )


class KMeansClustering(ClusteringWorkflowBase):
"""The automation workflow of using KMeans algorithm to make insightful products."""
Expand Down Expand Up @@ -176,35 +234,6 @@ def manual_hyper_parameters(cls) -> Dict:
clear_output()
return hyper_parameters

def _get_scores(self):
    """Print the KMeans scores and log them to MLflow.

    Reports the model inertia, the Calinski-Harabasz score and the silhouette
    score of ``self.X`` against ``self.model.labels_``. Each score is computed
    once and reused for both the console output and the MLflow metric.
    """
    print("-----* KMeans Scores *-----")
    inertia = self.model.inertia_
    print("Inertia Score: ", inertia)
    calinski_harabasz = metrics.calinski_harabasz_score(self.X, self.model.labels_)
    print("Calinski Harabasz Score: ", calinski_harabasz)
    # Computed once and reused for printing and logging instead of being recomputed.
    silhouette = metrics.silhouette_score(self.X, self.model.labels_)
    print("Silhouette Score: ", silhouette)
    mlflow.log_metric("Inertia Score", inertia)
    mlflow.log_metric("Calinski Harabasz Score", calinski_harabasz)
    mlflow.log_metric("Silhouette Score", silhouette)

@staticmethod
def _plot_silhouette_diagram(
    data: pd.DataFrame,
    cluster_labels: pd.DataFrame,
    cluster_centers_: np.ndarray,
    n_clusters: int,
    algorithm_name: str,
    local_path: str,
    mlflow_path: str,
) -> None:
    """Draw the silhouette diagram and save the figure together with its data.

    Parameters
    ----------
    data : pd.DataFrame
        The samples that were clustered.
    cluster_labels : pd.DataFrame
        The cluster label assigned to each sample.
    cluster_centers_ : np.ndarray
        Cluster centers, with one column per column of ``data``.
    n_clusters : int
        Number of clusters, forwarded to ``plot_silhouette_diagram``.
    algorithm_name : str
        Name of the algorithm; used in the figure artifact name.
    local_path : str
        Local output location passed to ``save_fig`` / ``save_data``.
    mlflow_path : str
        MLflow artifact location passed to ``save_fig`` / ``save_data``.
    """
    print("-----* Silhouette Diagram *-----")
    plot_silhouette_diagram(data, cluster_labels, cluster_centers_, n_clusters, algorithm_name)
    save_fig(f"Silhouette Diagram - {algorithm_name}", local_path, mlflow_path)
    # Save the labelled samples first, then the centers labelled with the data's column names.
    save_data(
        pd.concat([data, cluster_labels], axis=1),
        "Silhouette Diagram - Data With Labels",
        local_path,
        mlflow_path,
    )
    save_data(
        pd.DataFrame(cluster_centers_, columns=data.columns),
        "Silhouette Diagram - Cluster Centers",
        local_path,
        mlflow_path,
    )

@staticmethod
def _scatter2d(data: pd.DataFrame, cluster_labels: pd.DataFrame, algorithm_name: str, local_path: str, mlflow_path: str) -> None:
"""Plot the two-dimensional diagram of the clustering result."""
Expand All @@ -226,16 +255,7 @@ def _scatter3d(data: pd.DataFrame, cluster_labels: pd.DataFrame, algorithm_name:
def special_components(self, **kwargs: Union[Dict, np.ndarray, int]) -> None:
"""Invoke all special application functions for this algorithms by Scikit-learn framework."""
GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH = os.getenv("GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH")
self._get_scores()
self._plot_silhouette_diagram(
data=self.X,
cluster_labels=self.clustering_result["clustering result"],
cluster_centers_=self.get_cluster_centers(),
n_clusters=self.n_clusters,
algorithm_name=self.naming,
local_path=GEOPI_OUTPUT_ARTIFACTS_IMAGE_MODEL_OUTPUT_PATH,
mlflow_path=MLFLOW_ARTIFACT_IMAGE_MODEL_OUTPUT_PATH,
)

# Draw graphs when the number of principal components > 3
if self.X.shape[1] >= 3:
# choose two of dimensions to draw
Expand Down
82 changes: 82 additions & 0 deletions geochemistrypi/data_mining/model/func/algo_clustering/_common.py
Original file line number Diff line number Diff line change
@@ -1 +1,83 @@
# -*- coding: utf-8 -*-
from typing import Dict

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from rich import print
from sklearn.metrics import calinski_harabasz_score, silhouette_samples, silhouette_score


def score(X: pd.DataFrame, labels: pd.DataFrame) -> Dict:
    """Compute and print the scores of a clustering result.

    Parameters
    ----------
    X : pd.DataFrame (n_samples, n_components)
        The samples that were clustered.

    labels : pd.DataFrame (n_samples, n_components)
        The cluster label assigned to each sample.

    Returns
    -------
    scores : dict
        The silhouette score and the Calinski-Harabasz score, keyed by
        ``"silhouette_score"`` and ``"calinski_harabasz_score"``.
    """
    scores = {
        "silhouette_score": silhouette_score(X, labels),
        "calinski_harabasz_score": calinski_harabasz_score(X, labels),
    }
    print("silhouette_score: ", scores["silhouette_score"])
    print("calinski_harabasz_score:", scores["calinski_harabasz_score"])
    return scores


def plot_results(X, labels, algorithm_name: str, cluster_centers_=None) -> None:
    """Plot the clustering result on the first two features of the data.

    Parameters
    ----------
    X : pd.DataFrame (n_samples, n_components)
        The samples that were clustered; only the first two columns are drawn.

    labels : pd.DataFrame (n_samples, n_components)
        The cluster label assigned to each sample; used to colour the points.

    algorithm_name : str
        The name of the algorithm model; used in the figure title.

    cluster_centers_ : array-like (n_clusters, n_components), optional
        The cluster centers. They are drawn as red crosses when given.
        ``None`` and string values are skipped, so algorithms without
        centers can still be plotted.
    """
    sns.scatterplot(x=X.iloc[:, 0], y=X.iloc[:, 1], hue=labels, palette="viridis", s=50, alpha=0.8)
    # Skip missing centers: the default ``None`` (or a placeholder string) cannot be indexed.
    if cluster_centers_ is not None and not isinstance(cluster_centers_, str):
        # Normalise to a 2-D array so ndarray and DataFrame centers both support ``[:, i]``.
        centers = pd.DataFrame(cluster_centers_).to_numpy()
        plt.scatter(centers[:, 0], centers[:, 1], c="red", marker="X", s=200, label="Cluster Centers")
    plt.title(f"results - {algorithm_name}")
    plt.xlabel("Feature 1")
    plt.ylabel("Feature 2")
    plt.legend()


def plot_silhouette_diagram(X, labels, algorithm_name: str):
    """Plot the distribution of the per-sample silhouette coefficients.

    Parameters
    ----------
    X : pd.DataFrame (n_samples, n_components)
        The samples that were clustered.

    labels : pd.DataFrame (n_samples, n_components)
        The cluster label assigned to each sample.

    algorithm_name : str
        The name of the algorithm model; used in the figure title.
    """
    silhouette_values = silhouette_samples(X, labels)
    sns.histplot(silhouette_values, bins=30, kde=True)
    plt.title(f"Silhouette Diagram - {algorithm_name}")
    plt.xlabel("Silhouette Coefficient")
    plt.ylabel("Frequency")
    # No legend is drawn: the histogram has no labelled artists, so a
    # ``plt.legend()`` call would only emit a matplotlib warning.
3 changes: 3 additions & 0 deletions geochemistrypi/data_mining/process/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def activate(
# Save the model hyper-parameters
self.clt_workflow.save_hyper_parameters(hyper_parameters, self.model_name, os.getenv("GEOPI_OUTPUT_PARAMETERS_PATH"))

# Common components for every clustering algorithm
self.clt_workflow.common_components()

# special components of different algorithms
self.clt_workflow.special_components()

Expand Down

0 comments on commit 48f0e9e

Please sign in to comment.