From a2c1f28b1f0f9fa4cd8e438841d3128a8418ec6b Mon Sep 17 00:00:00 2001 From: Jang Jiseob Date: Wed, 15 Jan 2025 13:57:50 +0900 Subject: [PATCH] [onert/python] Introduce metric API (#14544) This commit introduces metric APIs. - categorical_accuracy.py : CategoricalAccuracy class - metric : Metric class to be provided custom metric ONE-DCO-1.0-Signed-off-by: ragmani --- .../experimental/train/metrics/__init__.py | 4 ++ .../train/metrics/categorical_accuracy.py | 56 +++++++++++++++++++ .../experimental/train/metrics/metric.py | 26 +++++++++ .../experimental/train/metrics/registry.py | 24 ++++++++ 4 files changed, 110 insertions(+) create mode 100644 runtime/onert/api/python/package/experimental/train/metrics/__init__.py create mode 100644 runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py create mode 100644 runtime/onert/api/python/package/experimental/train/metrics/metric.py create mode 100644 runtime/onert/api/python/package/experimental/train/metrics/registry.py diff --git a/runtime/onert/api/python/package/experimental/train/metrics/__init__.py b/runtime/onert/api/python/package/experimental/train/metrics/__init__.py new file mode 100644 index 00000000000..7ec5015b1e5 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/metrics/__init__.py @@ -0,0 +1,4 @@ +from .metric import Metric +from .categorical_accuracy import CategoricalAccuracy + +__all__ = ["Metric", "CategoricalAccuracy"] diff --git a/runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py b/runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py new file mode 100644 index 00000000000..b9a1f806a5b --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py @@ -0,0 +1,56 @@ +import numpy as np +from .metric import Metric + + +class CategoricalAccuracy(Metric): + """ + Metric for computing categorical accuracy. + """ + def __init__(self): + self.correct = 0 + self.total = 0 + self.axis = 0 + + def reset_state(self): + """ + Reset the metric's state. + """ + self.correct = 0 + self.total = 0 + + def update_state(self, outputs, expecteds): + """ + Update the metric's state based on the outputs and expecteds. + Args: + outputs (list of np.ndarray): List of model outputs for each output layer. + expecteds (list of np.ndarray): List of expected ground truth values for each output layer. + """ + if len(outputs) != len(expecteds): + raise ValueError( + "The number of outputs and expecteds must match. " + f"Got {len(outputs)} outputs and {len(expecteds)} expecteds.") + + for output, expected in zip(outputs, expecteds): + if output.shape[self.axis] != expected.shape[self.axis]: + raise ValueError( + f"Output and expected shapes must match along the specified axis {self.axis}. " + f"Got output shape {output.shape} and expected shape {expected.shape}." + ) + + batch_size = output.shape[self.axis] + for b in range(batch_size): + output_idx = np.argmax(output[b]) + expected_idx = np.argmax(expected[b]) + if output_idx == expected_idx: + self.correct += 1 + self.total += batch_size + + def result(self): + """ + Compute and return the final metric value. + Returns: + float: Metric value. + """ + if self.total == 0: + return 0.0 + return self.correct / self.total diff --git a/runtime/onert/api/python/package/experimental/train/metrics/metric.py b/runtime/onert/api/python/package/experimental/train/metrics/metric.py new file mode 100644 index 00000000000..619e914a852 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/metrics/metric.py @@ -0,0 +1,26 @@ +class Metric: + """ + Abstract base class for all metrics. + """ + def reset_state(self): + """ + Reset the metric's state. + """ + raise NotImplementedError + + def update_state(self, outputs, expecteds): + """ + Update the metric's state based on the outputs and expecteds. + Args: + outputs (np.ndarray): Model outputs. + expecteds (np.ndarray): Expected ground truth values. + """ + raise NotImplementedError + + def result(self): + """ + Compute and return the final metric value. + Returns: + float: Metric value. + """ + raise NotImplementedError diff --git a/runtime/onert/api/python/package/experimental/train/metrics/registry.py b/runtime/onert/api/python/package/experimental/train/metrics/registry.py new file mode 100644 index 00000000000..a7d65a2eee0 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/metrics/registry.py @@ -0,0 +1,24 @@ +from .categorical_accuracy import CategoricalAccuracy + + +class MetricsRegistry: + """ + Registry for creating metrics by name. + """ + _metrics = { + "categorical_accuracy": CategoricalAccuracy, + } + + @staticmethod + def create_metric(name): + """ + Create a metric instance by name. + Args: + name (str): Name of the metric. + Returns: + BaseMetric: Metric instance. + """ + if name not in MetricsRegistry._metrics: + raise ValueError( + f"Unknown Metric: {name}. Custom metric is not supported yet") + return MetricsRegistry._metrics[name]()