[onert/python] Introduce optimizer API #14543

Open · wants to merge 2 commits into base: master
__init__.py
@@ -0,0 +1,5 @@
from .sgd import SGD
from .adam import Adam
from onert.native.libnnfw_api_pybind import trainable_ops

__all__ = ["SGD", "Adam", "trainable_ops"]
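For orientation, a minimal usage sketch of the exported names. The import path below is an assumption (the diff does not show where the package lives); the class names, defaults, and the trainable_ops re-export come from the code above.

# Minimal usage sketch; the package path `onert.experimental.train.optimizer`
# is an assumption, not shown in this diff.
from onert.experimental.train.optimizer import SGD, Adam, trainable_ops

sgd = SGD(learning_rate=0.01)   # momentum defaults to 0.0; non-zero raises NotImplementedError
adam = Adam(learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7)

# Both inherit nums_trainable_ops from the Optimizer base class (see optimizer.py below).
assert adam.nums_trainable_ops == trainable_ops.ALL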
adam.py
@@ -0,0 +1,20 @@
from .optimizer import Optimizer


class Adam(Optimizer):
    """
    Adam optimizer.
    """
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7):
        """
        Initialize the Adam optimizer.

        Args:
            learning_rate (float): The learning rate for optimization.
            beta1 (float): Exponential decay rate for the first moment estimates.
            beta2 (float): Exponential decay rate for the second moment estimates.
            epsilon (float): Small constant to prevent division by zero.
        """
        super().__init__(learning_rate)
        self.beta1 = beta1
        self.beta2 = beta2
        self.epsilon = epsilon
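For reference, this is the update rule that beta1, beta2, and epsilon configure. The sketch below is not part of the PR (the actual weight update is expected to happen in the native runtime); it only documents the standard Adam step.

import numpy as np

# Standard Adam step, shown for reference only (not part of this PR).
def adam_step(param, grad, m, v, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    m = beta1 * m + (1 - beta1) * grad         # first moment (mean) estimate
    v = beta2 * v + (1 - beta2) * grad ** 2    # second moment (uncentered variance)
    m_hat = m / (1 - beta1 ** t)               # bias correction; t is the 1-based step count
    v_hat = v / (1 - beta2 ** t)
    param = param - lr * m_hat / (np.sqrt(v_hat) + eps)
    return param, m, v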
optimizer.py
@@ -0,0 +1,15 @@
from onert.native.libnnfw_api_pybind import trainable_ops


class Optimizer:
    """
    Base class for optimizers. Subclasses should implement the `step` method.
    """
Comment on lines +5 to +7

zetwhite (Contributor) commented on Jan 15, 2025:

Suggested change:

-    """
-    Base class for optimizers. Subclasses should implement the `step` method.
-    """
+    """
+    Base class for optimizers.
+    """

Since step() is deleted, this needs to be updated.

A contributor replied:

@ragmani But I think this has to be fixed. That's why I didn't approve this PR yet.

    def __init__(self, learning_rate=0.001, nums_trainable_ops=trainable_ops.ALL):
        """
        Initialize the optimizer.

        Args:
            learning_rate (float): The learning rate for optimization.
            nums_trainable_ops (trainable_ops): The number of operations to be
                trained (default: trainable_ops.ALL, train all operations).
        """
        self.learning_rate = learning_rate
        self.nums_trainable_ops = nums_trainable_ops
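Since step() has been removed (see the review thread above), the base class only stores hyperparameters and the trainable-operations setting; subclasses follow the same pattern as SGD and Adam. A hypothetical example of extending it (RMSprop is not part of this PR):

# Hypothetical subclass, shown only to illustrate the intended pattern.
class RMSprop(Optimizer):
    """
    RMSprop optimizer (illustrative only; not part of this PR).
    """
    def __init__(self, learning_rate=0.001, rho=0.9, epsilon=1e-7):
        super().__init__(learning_rate)
        self.rho = rho
        self.epsilon = epsilon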
sgd.py
@@ -0,0 +1,21 @@
from .optimizer import Optimizer


class SGD(Optimizer):
    """
    Stochastic Gradient Descent (SGD) optimizer.
    """
    def __init__(self, learning_rate=0.001, momentum=0.0):
        """
        Initialize the SGD optimizer.

        Args:
            learning_rate (float): The learning rate for optimization.
            momentum (float): Momentum factor (default: 0.0).
        """
        super().__init__(learning_rate)

        if momentum != 0.0:
            raise NotImplementedError(
                "Momentum is not supported in the current version of SGD.")
        self.momentum = momentum
        self.velocity = None
zetwhite marked this conversation as resolved.
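For reference, the velocity attribute above is reserved for the momentum update that the NotImplementedError guards against. A conventional formulation (not code from this PR):

# Conventional SGD-with-momentum step, for reference only. Plain SGD is the
# momentum == 0.0 case: param -= learning_rate * grad.
def sgd_momentum_step(param, grad, velocity, lr=0.001, momentum=0.9):
    velocity = momentum * velocity - lr * grad  # decaying accumulation of past gradients
    param = param + velocity
    return param, velocity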