From 237f86495eb03a48b5e8c39c195c5f563ad81524 Mon Sep 17 00:00:00 2001 From: Jang Jiseob Date: Wed, 15 Jan 2025 13:50:36 +0900 Subject: [PATCH] [onert/python] Introduces loss API (#14542) * [onert/python] Introduces loss API This commit introduces loss API. - cce.py : Categorical CrossEntropy Loss - mse.py : Mean Squared Error Loss ONE-DCO-1.0-Signed-off-by: ragmani * Remove __call__ --- .../experimental/train/losses/__init__.py | 5 ++++ .../package/experimental/train/losses/cce.py | 15 ++++++++++++ .../package/experimental/train/losses/loss.py | 24 +++++++++++++++++++ .../package/experimental/train/losses/mse.py | 15 ++++++++++++ 4 files changed, 59 insertions(+) create mode 100644 runtime/onert/api/python/package/experimental/train/losses/__init__.py create mode 100644 runtime/onert/api/python/package/experimental/train/losses/cce.py create mode 100644 runtime/onert/api/python/package/experimental/train/losses/loss.py create mode 100644 runtime/onert/api/python/package/experimental/train/losses/mse.py diff --git a/runtime/onert/api/python/package/experimental/train/losses/__init__.py b/runtime/onert/api/python/package/experimental/train/losses/__init__.py new file mode 100644 index 00000000000..12977444839 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/losses/__init__.py @@ -0,0 +1,5 @@ +from .cce import CategoricalCrossentropy +from .mse import MeanSquaredError +from onert.native.libnnfw_api_pybind import lossinfo + +__all__ = ["CategoricalCrossentropy", "MeanSquaredError", "lossinfo", "loss"] diff --git a/runtime/onert/api/python/package/experimental/train/losses/cce.py b/runtime/onert/api/python/package/experimental/train/losses/cce.py new file mode 100644 index 00000000000..0fb1fa89729 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/losses/cce.py @@ -0,0 +1,15 @@ +import numpy as np +from .loss import LossFunction + + +class CategoricalCrossentropy(LossFunction): + """ + Categorical Cross-Entropy Loss Function with reduction type. + """ + def __init__(self, reduction="mean"): + """ + Initialize the Categorical Cross-Entropy loss function. + Args: + reduction (str): Reduction type ('mean', 'sum'). + """ + super().__init__(reduction) diff --git a/runtime/onert/api/python/package/experimental/train/losses/loss.py b/runtime/onert/api/python/package/experimental/train/losses/loss.py new file mode 100644 index 00000000000..0719ba68bfa --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/losses/loss.py @@ -0,0 +1,24 @@ +from onert.native.libnnfw_api_pybind import loss_reduction + + +class LossFunction: + """ + Base class for loss functions with reduction type. + """ + def __init__(self, reduction="mean"): + """ + Initialize the Categorical Cross-Entropy loss function. + Args: + reduction (str): Reduction type ('mean', 'sum'). + """ + reduction_mapping = { + "mean": loss_reduction.SUM_OVER_BATCH_SIZE, + "sum": loss_reduction.SUM + } + + # Validate and assign the reduction type + if reduction not in reduction_mapping: + raise ValueError( + f"Invalid reduction type. Choose from {list(reduction_mapping.keys())}.") + + self.reduction = reduction_mapping[reduction] diff --git a/runtime/onert/api/python/package/experimental/train/losses/mse.py b/runtime/onert/api/python/package/experimental/train/losses/mse.py new file mode 100644 index 00000000000..ed9e37e3482 --- /dev/null +++ b/runtime/onert/api/python/package/experimental/train/losses/mse.py @@ -0,0 +1,15 @@ +import numpy as np +from .loss import LossFunction + + +class MeanSquaredError(LossFunction): + """ + Mean Squared Error (MSE) Loss Function with reduction type. + """ + def __init__(self, reduction="mean"): + """ + Initialize the MSE loss function. + Args: + reduction (str): Reduction type ('mean', 'sum'). + """ + super().__init__(reduction)