Skip to content

Commit

Permalink
Merge pull request #45 from mathysgrapotte/linting
Browse files Browse the repository at this point in the history
Linting part one
  • Loading branch information
mathysgrapotte authored Jan 22, 2025
2 parents 30d76d9 + 842cb85 commit 5b57215
Show file tree
Hide file tree
Showing 28 changed files with 728 additions and 658 deletions.
1 change: 1 addition & 0 deletions src/stimulus/analysis/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Analysis package for stimulus, analysis_default is to be refactored, see git issues."""
99 changes: 63 additions & 36 deletions src/stimulus/analysis/analysis_default.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Default analysis module for stimulus package."""

import math
from typing import Any, Tuple
from typing import Any

import matplotlib
import matplotlib as mpl
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
Expand All @@ -18,14 +20,22 @@ class Analysis:
"""

@staticmethod
def get_grid_shape(n: int) -> Tuple[int, int]:
def get_grid_shape(n: int) -> tuple[int, int]:
"""Calculates rows and columns for a rectangle layout (flexible)."""
rows = int(math.ceil(math.sqrt(n))) # Round up the square root for rows
cols = int(math.ceil(n / rows)) # Calculate columns based on rows
return rows, cols

@staticmethod
def heatmap(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", **kwargs):
def heatmap(
data: np.ndarray,
row_labels: list[str],
col_labels: list[str],
ax: Any | None = None,
cbar_kw: dict | None = None,
cbarlabel: str = "",
**kwargs: Any,
) -> tuple[Any, Any]:
"""Create a heatmap from a numpy array and two lists of labels.
Parameters
Expand Down Expand Up @@ -80,7 +90,14 @@ def heatmap(data, row_labels, col_labels, ax=None, cbar_kw=None, cbarlabel="", *
return im, cbar

@staticmethod
def annotate_heatmap(im, data=None, valfmt="{x:.2f}", textcolors=("black", "white"), threshold=None, **textkw):
def annotate_heatmap(
im: Any,
data: np.ndarray | None = None,
valfmt: str = "{x:.2f}",
textcolors: tuple[str, str] = ("black", "white"),
threshold: float | None = None,
**textkw: Any,
) -> list[Any]:
"""A function to annotate a heatmap.
Parameters
Expand Down Expand Up @@ -108,19 +125,16 @@ def annotate_heatmap(im, data=None, valfmt="{x:.2f}", textcolors=("black", "whit
data = im.get_array()

# Normalize the threshold to the images color range.
if threshold is not None:
threshold = im.norm(threshold)
else:
threshold = im.norm(data.max()) / 2.0
threshold = im.norm(threshold) if threshold is not None else im.norm(data.max()) / 2.0

# Set default alignment to center, but allow it to be
# overwritten by textkw.
kw = dict(horizontalalignment="center", verticalalignment="center")
kw = {"horizontalalignment": "center", "verticalalignment": "center"}
kw.update(textkw)

# Get the formatter in case a string is supplied
if isinstance(valfmt, str):
valfmt = matplotlib.ticker.StrMethodFormatter(valfmt)
valfmt = mpl.ticker.StrMethodFormatter(valfmt)

# Loop over the data and create a `Text` for each "pixel".
# Change the text's color depending on the data.
Expand All @@ -142,11 +156,18 @@ class AnalysisPerformanceTune(Analysis):
TODO or maybe one pdf for all models with all metrics, colored by model. One for train, one for val.
"""

def __init__(self, results_path: str):
def __init__(self, results_path: str) -> None:
"""Initialize the AnalysisPerformanceTune class."""
super().__init__()
self.results = pd.read_csv(results_path)

def plot_metric_vs_iteration(self, metrics: list, figsize: tuple = (10, 10), output: str = None):
def plot_metric_vs_iteration(
self,
metrics: list,
figsize: tuple = (10, 10),
output: str | None = None,
) -> None:
"""Plot metrics vs iteration for training and validation."""
# create figure
rows, cols = self.get_grid_shape(len(metrics))
fig, axs = plt.subplots(rows, cols, figsize=figsize)
Expand All @@ -169,7 +190,7 @@ def plot_metric_vs_iteration(self, metrics: list, figsize: tuple = (10, 10), out
plt.savefig(output)
plt.show()

def plot_metric_vs_iteration_per_metric(self, ax: Any, metric: str):
def plot_metric_vs_iteration_per_metric(self, ax: Any, metric: str) -> Any:
"""Plot the metric vs the iteration."""
# plot training performance
ax.plot(
Expand Down Expand Up @@ -200,7 +221,8 @@ def plot_metric_vs_iteration_per_metric(self, ax: Any, metric: str):
class AnalysisRobustness(Analysis):
"""Report the robustness of the models."""

def __init__(self, metrics: list, experiment: object, batch_size: int):
def __init__(self, metrics: list, experiment: object, batch_size: int) -> None:
"""Initialize the AnalysisRobustness class."""
super().__init__()
self.metrics = metrics
self.experiment = experiment
Expand All @@ -209,11 +231,13 @@ def __init__(self, metrics: list, experiment: object, batch_size: int):
def get_performance_table(self, names: list, model_list: dict, data_list: list) -> pd.DataFrame:
"""Compute the performance metrics of each model on each dataset.
`names` is a list of names that identifies each model.
The corresponding dataset used to train the model will also be identified equally.
Args:
names: List of names that identifies each model.
model_list: Dictionary of models in same order as data_list.
data_list: List of datasets used for training.
`model_list` should have the same order as `data_list`.
So model_list[i] is obtained by training on data_list[i].
Returns:
DataFrame containing performance metrics.
"""
# check same length
if (len(names) != len(model_list)) and (len(names) != len(data_list)):
Expand Down Expand Up @@ -249,17 +273,17 @@ def get_performance_table_for_one_model(self, names: list, model: object, data_l
def get_average_performance_table(self, df: pd.DataFrame) -> pd.DataFrame:
"""Compute the average performance of each model on each dataset.
`df` containing the performance table of each model on each dataset.
"""
df = df[self.metrics + ["model"]]
df = df.groupby(["model"]).mean().reset_index()
return df

def plot_performance_heatmap(self, df: pd.DataFrame, figsize: tuple = (10, 10), output: str = None):
"""Plot the performance of each model on each dataset.
Args:
df: DataFrame containing the performance table.
`df` containing the performance table of each model on each dataset.
Returns:
DataFrame with averaged metrics.
"""
df = df[[*self.metrics, "model"]] # Use list unpacking instead of concatenation
return df.groupby(["model"]).mean().reset_index()

def plot_performance_heatmap(self, df: pd.DataFrame, figsize: tuple = (10, 10), output: str | None = None) -> None:
"""Plot the performance of each model on each dataset."""
# create figure
rows, cols = self.get_grid_shape(len(self.metrics))
fig, axs = plt.subplots(rows, cols, figsize=figsize)
Expand All @@ -279,19 +303,22 @@ def plot_performance_heatmap(self, df: pd.DataFrame, figsize: tuple = (10, 10),

# plot heatmap
im, cbar = self.heatmap(mat, mat.index, mat.columns, ax=ax, cmap="YlGn", cbarlabel=self.metrics[i])
texts = self.annotate_heatmap(im, valfmt="{x:.2f}")
self.annotate_heatmap(im, valfmt="{x:.2f}") # Don't assign to unused variable

# save plot
plt.tight_layout()
if output:
plt.savefig(output)
plt.show()

def plot_delta_performance(self, metric: str, df: pd.DataFrame, figsize: tuple = (10, 10), output: str = None):
"""Plot the delta performance of each model on each dataset, according to one specific metric.
`df` containing the performance table of each model on each dataset.
"""
def plot_delta_performance(
self,
metric: str,
df: pd.DataFrame,
figsize: tuple = (10, 10),
output: str | None = None,
) -> None:
"""Plot the delta performance of each model on each dataset."""
# create figure
rows, cols = self.get_grid_shape(len(df["model"].unique()))
fig, axs = plt.subplots(rows, cols, figsize=figsize)
Expand Down Expand Up @@ -321,7 +348,7 @@ def plot_delta_performance(self, metric: str, df: pd.DataFrame, figsize: tuple =
plt.savefig(output)
plt.show()

def plot_delta_performance_for_one_model(self, ax: Any, metric: str, df: pd.DataFrame, model_name: str):
def plot_delta_performance_for_one_model(self, ax: Any, metric: str, df: pd.DataFrame, model_name: str) -> Any:
"""Plot the delta performance of one model."""
df = self.parse_delta_performance_for_one_model(metric, df, model_name)

Expand All @@ -336,7 +363,7 @@ def plot_delta_performance_for_one_model(self, ax: Any, metric: str, df: pd.Data

return ax

def parse_delta_performance_for_one_model(self, metric: str, df: pd.DataFrame, model_name: str):
def parse_delta_performance_for_one_model(self, metric: str, df: pd.DataFrame, model_name: str) -> pd.DataFrame:
"""Compute the delta performance of one model."""
# filter data frame
df = df[["data", "model", metric]]
Expand Down
1 change: 1 addition & 0 deletions src/stimulus/cli/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Command line interface package for the stimulus library."""
Loading

0 comments on commit 5b57215

Please sign in to comment.