Skip to content

Commit

Permalink
Remove __call__
Browse files Browse the repository at this point in the history
  • Loading branch information
ragmani committed Jan 14, 2025
1 parent a2d9e36 commit 92540e9
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 33 deletions.
18 changes: 0 additions & 18 deletions runtime/onert/api/python/package/experimental/train/losses/cce.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,3 @@ def __init__(self, reduction="mean"):
reduction (str): Reduction type ('mean', 'sum').
"""
super().__init__(reduction)

def __call__(self, y_true, y_pred):
"""
Compute the Categorical Cross-Entropy loss.
Args:
y_true (np.ndarray): One-hot encoded ground truth values.
y_pred (np.ndarray): Predicted probabilities.
Returns:
float or np.ndarray: Computed loss value(s).
"""
epsilon = 1e-7 # Prevent log(0)
y_pred = np.clip(y_pred, epsilon, 1 - epsilon)
loss = -np.sum(y_true * np.log(y_pred), axis=1)

if self.reduction == "mean":
return np.mean(loss)
elif self.reduction == "sum":
return np.sum(loss)
15 changes: 0 additions & 15 deletions runtime/onert/api/python/package/experimental/train/losses/mse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,3 @@ def __init__(self, reduction="mean"):
reduction (str): Reduction type ('mean', 'sum').
"""
super().__init__(reduction)

def __call__(self, y_true, y_pred):
"""
Compute the Mean Squared Error (MSE) loss.
Args:
y_true (np.ndarray): Ground truth values.
y_pred (np.ndarray): Predicted values.
Returns:
float or np.ndarray: Computed MSE loss value(s).
"""
loss = (y_true - y_pred)**2
if self.reduction == "mean":
return np.mean(loss)
elif self.reduction == "sum":
return np.sum(loss)

0 comments on commit 92540e9

Please sign in to comment.