[onert/python] Introduce optimizer API (#14543)
This commit introduces the optimizer API.
  - adam.py : Adam optimizer class
  - sgd.py  : SGD optimizer class

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani authored Jan 20, 2025
1 parent e514715 commit d49a507
Showing 4 changed files with 61 additions and 0 deletions.
__init__.py
@@ -0,0 +1,5 @@
from .sgd import SGD
from .adam import Adam
from onert.native.libnnfw_api_pybind import trainable_ops

__all__ = ["SGD", "Adam", "trainable_ops"]
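
For orientation, here is a minimal usage sketch of the names exported above. The import path onert.experimental.train.optimizer is an assumption based on this package layout, not something stated in the commit; adjust it to wherever the package actually resolves.

# Usage sketch; the import path below is an assumption, not part of this commit.
from onert.experimental.train.optimizer import SGD, Adam, trainable_ops

sgd = SGD()                      # defaults: learning_rate=0.001, momentum=0.0
adam = Adam(learning_rate=0.01)  # beta1, beta2, and epsilon keep their defaults

print(adam.learning_rate, adam.beta1, adam.beta2, adam.epsilon)
print(sgd.nums_trainable_ops)    # trainable_ops.ALL by default (from the base class)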
adam.py
@@ -0,0 +1,20 @@
from .optimizer import Optimizer


class Adam(Optimizer):
    """
    Adam optimizer.
    """
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7):
        """
        Initialize the Adam optimizer.
        Args:
            learning_rate (float): The learning rate for optimization.
            beta1 (float): Exponential decay rate for the first moment estimates.
            beta2 (float): Exponential decay rate for the second moment estimates.
            epsilon (float): Small constant to prevent division by zero.
        """
        super().__init__(learning_rate)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
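
The Adam class above only stores its hyperparameters; the update itself is presumably applied by the native onert training runtime. Purely to clarify what beta1, beta2, and epsilon control, here is a reference sketch of the standard Adam update rule in NumPy; it is not onert code.

# Reference sketch of the standard Adam update (illustrative only, not onert code).
import numpy as np

def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7):
    """Apply one standard Adam update to a NumPy parameter array."""
    m = beta1 * m + (1.0 - beta1) * grad        # first-moment (mean) estimate
    v = beta2 * v + (1.0 - beta2) * grad ** 2   # second-moment (variance) estimate
    m_hat = m / (1.0 - beta1 ** t)              # bias correction for step t (t >= 1)
    v_hat = v / (1.0 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + epsilon)
    return param, m, v

w = np.array([1.0, -2.0])
m, v = np.zeros_like(w), np.zeros_like(w)
w, m, v = adam_step(w, grad=np.array([0.1, -0.3]), m=m, v=v, t=1)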
optimizer.py
@@ -0,0 +1,15 @@
from onert.native.libnnfw_api_pybind import trainable_ops


class Optimizer:
    """
    Base class for optimizers.
    """
    def __init__(self, learning_rate=0.001, nums_trainable_ops=trainable_ops.ALL):
        """
        Initialize the optimizer.
        Args:
            learning_rate (float): The learning rate for optimization.
            nums_trainable_ops (trainable_ops): Which operations to train (default: trainable_ops.ALL).
        """
        self.learning_rate = learning_rate
        self.nums_trainable_ops = nums_trainable_ops
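
New optimizers can follow the same pattern as the concrete classes in this commit: subclass Optimizer, forward the learning rate, and store any extra hyperparameters as attributes. A hypothetical Adadelta is sketched below; it is not part of this commit and only illustrates the subclassing pattern (the relative import assumes the file lives inside this package next to adam.py and sgd.py).

# Hypothetical subclass, not part of this commit; shown only to illustrate the pattern.
from .optimizer import Optimizer


class Adadelta(Optimizer):
    """
    Sketch of an Adadelta-style optimizer following the same conventions.
    """
    def __init__(self, learning_rate=1.0, rho=0.95, epsilon=1e-7):
        super().__init__(learning_rate)  # stores learning_rate and nums_trainable_ops
        self.rho = rho                   # decay rate for the running average of squared gradients
        self.epsilon = epsilon           # small constant for numerical stability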
sgd.py
@@ -0,0 +1,21 @@
from .optimizer import Optimizer


class SGD(Optimizer):
    """
    Stochastic Gradient Descent (SGD) optimizer.
    """
    def __init__(self, learning_rate=0.001, momentum=0.0):
        """
        Initialize the SGD optimizer.
        Args:
            learning_rate (float): The learning rate for optimization.
            momentum (float): Momentum factor (default: 0.0).
        """
        super().__init__(learning_rate)

        if momentum != 0.0:
            raise NotImplementedError(
                "Momentum is not supported in the current version of SGD.")
        self.momentum = momentum
        self.velocity = None
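
Since nonzero momentum is rejected at construction time, callers can rely only on plain SGD for now. A minimal check of that behavior:

# Minimal check of the momentum restriction (illustrative usage only).
sgd = SGD(learning_rate=0.01)              # plain SGD is accepted
try:
    SGD(learning_rate=0.01, momentum=0.9)  # any nonzero momentum raises
except NotImplementedError as err:
    print(err)  # Momentum is not supported in the current version of SGD.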
