-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
This commit introduces metric APIs. - categorical_accuracy.py : CategoricalAccuracy class - metric : Metric class to be provided custom metric ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
- Loading branch information
Showing
4 changed files
with
110 additions
and
0 deletions.
There are no files selected for viewing
4 changes: 4 additions & 0 deletions
4
runtime/onert/api/python/package/experimental/train/metrics/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .metric import Metric | ||
from .categorical_accuracy import CategoricalAccuracy | ||
|
||
__all__ = ["Metric", "CategoricalAccuracy"] |
56 changes: 56 additions & 0 deletions
56
runtime/onert/api/python/package/experimental/train/metrics/categorical_accuracy.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
import numpy as np | ||
from .metric import Metric | ||
|
||
|
||
class CategoricalAccuracy(Metric): | ||
""" | ||
Metric for computing categorical accuracy. | ||
""" | ||
def __init__(self): | ||
self.correct = 0 | ||
self.total = 0 | ||
self.axis = 0 | ||
|
||
def reset_state(self): | ||
""" | ||
Reset the metric's state. | ||
""" | ||
self.correct = 0 | ||
self.total = 0 | ||
|
||
def update_state(self, outputs, expecteds): | ||
""" | ||
Update the metric's state based on the outputs and expecteds. | ||
Args: | ||
outputs (list of np.ndarray): List of model outputs for each output layer. | ||
expecteds (list of np.ndarray): List of expected ground truth values for each output layer. | ||
""" | ||
if len(outputs) != len(expecteds): | ||
raise ValueError( | ||
"The number of outputs and expecteds must match. " | ||
f"Got {len(outputs)} outputs and {len(expecteds)} expecteds.") | ||
|
||
for output, expected in zip(outputs, expecteds): | ||
if output.shape[self.axis] != expected.shape[self.axis]: | ||
raise ValueError( | ||
f"Output and expected shapes must match along the specified axis {self.axis}. " | ||
f"Got output shape {output.shape} and expected shape {expected.shape}." | ||
) | ||
|
||
batch_size = output.shape[self.axis] | ||
for b in range(batch_size): | ||
output_idx = np.argmax(output[b]) | ||
expected_idx = np.argmax(expected[b]) | ||
if output_idx == expected_idx: | ||
self.correct += 1 | ||
self.total += batch_size | ||
|
||
def result(self): | ||
""" | ||
Compute and return the final metric value. | ||
Returns: | ||
float: Metric value. | ||
""" | ||
if self.total == 0: | ||
return 0.0 | ||
return self.correct / self.total |
26 changes: 26 additions & 0 deletions
26
runtime/onert/api/python/package/experimental/train/metrics/metric.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
class Metric: | ||
""" | ||
Abstract base class for all metrics. | ||
""" | ||
def reset_state(self): | ||
""" | ||
Reset the metric's state. | ||
""" | ||
raise NotImplementedError | ||
|
||
def update_state(self, outputs, expecteds): | ||
""" | ||
Update the metric's state based on the outputs and expecteds. | ||
Args: | ||
outputs (np.ndarray): Model outputs. | ||
expecteds (np.ndarray): Expected ground truth values. | ||
""" | ||
raise NotImplementedError | ||
|
||
def result(self): | ||
""" | ||
Compute and return the final metric value. | ||
Returns: | ||
float: Metric value. | ||
""" | ||
raise NotImplementedError |
24 changes: 24 additions & 0 deletions
24
runtime/onert/api/python/package/experimental/train/metrics/registry.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from .categorical_accuracy import CategoricalAccuracy | ||
|
||
|
||
class MetricsRegistry: | ||
""" | ||
Registry for creating metrics by name. | ||
""" | ||
_metrics = { | ||
"categorical_accuracy": CategoricalAccuracy, | ||
} | ||
|
||
@staticmethod | ||
def create_metric(name): | ||
""" | ||
Create a metric instance by name. | ||
Args: | ||
name (str): Name of the metric. | ||
Returns: | ||
BaseMetric: Metric instance. | ||
""" | ||
if name not in MetricsRegistry._metrics: | ||
raise ValueError( | ||
f"Unknown Metric: {name}. Custom metric is not supported yet") | ||
return MetricsRegistry._metrics[name]() |