-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[onert/python] Introduces loss API (#14542)
* [onert/python] Introduces loss API This commit introduces loss API. - cce.py : Categorical CrossEntropy Loss - mse.py : Mean Squared Error Loss ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]> * Remove __call__
- Loading branch information
Showing
4 changed files
with
59 additions
and
0 deletions.
There are no files selected for viewing
5 changes: 5 additions & 0 deletions
5
runtime/onert/api/python/package/experimental/train/losses/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .cce import CategoricalCrossentropy | ||
from .mse import MeanSquaredError | ||
from onert.native.libnnfw_api_pybind import lossinfo | ||
|
||
__all__ = ["CategoricalCrossentropy", "MeanSquaredError", "lossinfo", "loss"] |
15 changes: 15 additions & 0 deletions
15
runtime/onert/api/python/package/experimental/train/losses/cce.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import numpy as np | ||
from .loss import LossFunction | ||
|
||
|
||
class CategoricalCrossentropy(LossFunction): | ||
""" | ||
Categorical Cross-Entropy Loss Function with reduction type. | ||
""" | ||
def __init__(self, reduction="mean"): | ||
""" | ||
Initialize the Categorical Cross-Entropy loss function. | ||
Args: | ||
reduction (str): Reduction type ('mean', 'sum'). | ||
""" | ||
super().__init__(reduction) |
24 changes: 24 additions & 0 deletions
24
runtime/onert/api/python/package/experimental/train/losses/loss.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from onert.native.libnnfw_api_pybind import loss_reduction | ||
|
||
|
||
class LossFunction: | ||
""" | ||
Base class for loss functions with reduction type. | ||
""" | ||
def __init__(self, reduction="mean"): | ||
""" | ||
Initialize the Categorical Cross-Entropy loss function. | ||
Args: | ||
reduction (str): Reduction type ('mean', 'sum'). | ||
""" | ||
reduction_mapping = { | ||
"mean": loss_reduction.SUM_OVER_BATCH_SIZE, | ||
"sum": loss_reduction.SUM | ||
} | ||
|
||
# Validate and assign the reduction type | ||
if reduction not in reduction_mapping: | ||
raise ValueError( | ||
f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}.") | ||
|
||
self.reduction = reduction_mapping[reduction] |
15 changes: 15 additions & 0 deletions
15
runtime/onert/api/python/package/experimental/train/losses/mse.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import numpy as np | ||
from .loss import LossFunction | ||
|
||
|
||
class MeanSquaredError(LossFunction): | ||
""" | ||
Mean Squared Error (MSE) Loss Function with reduction type. | ||
""" | ||
def __init__(self, reduction="mean"): | ||
""" | ||
Initialize the MSE loss function. | ||
Args: | ||
reduction (str): Reduction type ('mean', 'sum'). | ||
""" | ||
super().__init__(reduction) |