Skip to content

Commit

Permalink
[onert/python] Introduces loss API (#14542)
Browse files Browse the repository at this point in the history
* [onert/python] Introduces loss API

This commit introduces loss API.
  - cce.py : Categorical CrossEntropy Loss
  - mse.py : Mean Squared Error Loss

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>

* Remove __call__
  • Loading branch information
ragmani authored Jan 15, 2025
1 parent b7d4e5e commit 237f864
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .cce import CategoricalCrossentropy
from .mse import MeanSquaredError
from onert.native.libnnfw_api_pybind import lossinfo

__all__ = ["CategoricalCrossentropy", "MeanSquaredError", "lossinfo", "loss"]
15 changes: 15 additions & 0 deletions runtime/onert/api/python/package/experimental/train/losses/cce.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np
from .loss import LossFunction


class CategoricalCrossentropy(LossFunction):
"""
Categorical Cross-Entropy Loss Function with reduction type.
"""
def __init__(self, reduction="mean"):
"""
Initialize the Categorical Cross-Entropy loss function.
Args:
reduction (str): Reduction type ('mean', 'sum').
"""
super().__init__(reduction)
24 changes: 24 additions & 0 deletions runtime/onert/api/python/package/experimental/train/losses/loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from onert.native.libnnfw_api_pybind import loss_reduction


class LossFunction:
"""
Base class for loss functions with reduction type.
"""
def __init__(self, reduction="mean"):
"""
Initialize the Categorical Cross-Entropy loss function.
Args:
reduction (str): Reduction type ('mean', 'sum').
"""
reduction_mapping = {
"mean": loss_reduction.SUM_OVER_BATCH_SIZE,
"sum": loss_reduction.SUM
}

# Validate and assign the reduction type
if reduction not in reduction_mapping:
raise ValueError(
f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}.")

self.reduction = reduction_mapping[reduction]
15 changes: 15 additions & 0 deletions runtime/onert/api/python/package/experimental/train/losses/mse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import numpy as np
from .loss import LossFunction


class MeanSquaredError(LossFunction):
"""
Mean Squared Error (MSE) Loss Function with reduction type.
"""
def __init__(self, reduction="mean"):
"""
Initialize the MSE loss function.
Args:
reduction (str): Reduction type ('mean', 'sum').
"""
super().__init__(reduction)

0 comments on commit 237f864

Please sign in to comment.