Skip to content

Commit

Permalink
[onert/python] Introduce metric API (#14544)
Browse files Browse the repository at this point in the history
This commit introduces metric APIs.
  - categorical_accuracy.py : CategoricalAccuracy class
  - metric : Metric class to be provided custom metric

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
  • Loading branch information
ragmani authored Jan 15, 2025
1 parent 237f864 commit a2c1f28
Show file tree
Hide file tree
Showing 4 changed files with 110 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .metric import Metric
from .categorical_accuracy import CategoricalAccuracy

__all__ = ["Metric", "CategoricalAccuracy"]
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
from .metric import Metric


class CategoricalAccuracy(Metric):
"""
Metric for computing categorical accuracy.
"""
def __init__(self):
self.correct = 0
self.total = 0
self.axis = 0

def reset_state(self):
"""
Reset the metric's state.
"""
self.correct = 0
self.total = 0

def update_state(self, outputs, expecteds):
"""
Update the metric's state based on the outputs and expecteds.
Args:
outputs (list of np.ndarray): List of model outputs for each output layer.
expecteds (list of np.ndarray): List of expected ground truth values for each output layer.
"""
if len(outputs) != len(expecteds):
raise ValueError(
"The number of outputs and expecteds must match. "
f"Got {len(outputs)} outputs and {len(expecteds)} expecteds.")

for output, expected in zip(outputs, expecteds):
if output.shape[self.axis] != expected.shape[self.axis]:
raise ValueError(
f"Output and expected shapes must match along the specified axis {self.axis}. "
f"Got output shape {output.shape} and expected shape {expected.shape}."
)

batch_size = output.shape[self.axis]
for b in range(batch_size):
output_idx = np.argmax(output[b])
expected_idx = np.argmax(expected[b])
if output_idx == expected_idx:
self.correct += 1
self.total += batch_size

def result(self):
"""
Compute and return the final metric value.
Returns:
float: Metric value.
"""
if self.total == 0:
return 0.0
return self.correct / self.total
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
class Metric:
"""
Abstract base class for all metrics.
"""
def reset_state(self):
"""
Reset the metric's state.
"""
raise NotImplementedError

def update_state(self, outputs, expecteds):
"""
Update the metric's state based on the outputs and expecteds.
Args:
outputs (np.ndarray): Model outputs.
expecteds (np.ndarray): Expected ground truth values.
"""
raise NotImplementedError

def result(self):
"""
Compute and return the final metric value.
Returns:
float: Metric value.
"""
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from .categorical_accuracy import CategoricalAccuracy


class MetricsRegistry:
"""
Registry for creating metrics by name.
"""
_metrics = {
"categorical_accuracy": CategoricalAccuracy,
}

@staticmethod
def create_metric(name):
"""
Create a metric instance by name.
Args:
name (str): Name of the metric.
Returns:
BaseMetric: Metric instance.
"""
if name not in MetricsRegistry._metrics:
raise ValueError(
f"Unknown Metric: {name}. Custom metric is not supported yet")
return MetricsRegistry._metrics[name]()

0 comments on commit a2c1f28

Please sign in to comment.