-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[onert/python] Introduce optimizer API (#14543)
This commit introduces optimizer API. - adam.py : Adam optimizer class - sgd.py : SGD optimizer class ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
- Loading branch information
Showing
4 changed files
with
61 additions
and
0 deletions.
There are no files selected for viewing
5 changes: 5 additions & 0 deletions
5
runtime/onert/api/python/package/experimental/train/optimizer/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
from .sgd import SGD | ||
from .adam import Adam | ||
from onert.native.libnnfw_api_pybind import trainable_ops | ||
|
||
__all__ = ["SGD", "Adam", "trainable_ops"] |
20 changes: 20 additions & 0 deletions
20
runtime/onert/api/python/package/experimental/train/optimizer/adam.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
from .optimizer import Optimizer | ||
|
||
|
||
class Adam(Optimizer):
    """
    Adaptive Moment Estimation (Adam) optimizer.

    This class only records the hyperparameters; the actual update rule is
    applied elsewhere (presumably by the native training backend — confirm).
    """
    def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7):
        """
        Create an Adam optimizer configuration.

        Args:
            learning_rate (float): Step size used for parameter updates.
            beta1 (float): Exponential decay rate for the first-moment
                (mean) estimates.
            beta2 (float): Exponential decay rate for the second-moment
                (uncentered variance) estimates.
            epsilon (float): Small constant added to denominators to
                prevent division by zero.
        """
        # The shared learning-rate handling lives in the base class.
        super().__init__(learning_rate)
        # Keep the Adam-specific hyperparameters on the instance.
        self.beta1, self.beta2, self.epsilon = beta1, beta2, epsilon
15 changes: 15 additions & 0 deletions
15
runtime/onert/api/python/package/experimental/train/optimizer/optimizer.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
from onert.native.libnnfw_api_pybind import trainable_ops | ||
|
||
|
||
class Optimizer:
    """
    Base class for optimizers.

    Holds the hyperparameters common to all optimizers; concrete
    subclasses add their own settings on top.
    """
    def __init__(self, learning_rate=0.001, nums_trainable_ops=trainable_ops.ALL):
        """
        Initialize the optimizer.

        Args:
            learning_rate (float): The learning rate for optimization.
            nums_trainable_ops: Selection of operations to train; defaults
                to ``trainable_ops.ALL`` from the native bindings
                (``onert.native.libnnfw_api_pybind``).
        """
        # Step size consumed by whichever component performs the updates.
        self.learning_rate = learning_rate
        # Opaque value from the native bindings — presumably an enum/flag
        # choosing which ops are trainable; verify against the backend.
        self.nums_trainable_ops = nums_trainable_ops
21 changes: 21 additions & 0 deletions
21
runtime/onert/api/python/package/experimental/train/optimizer/sgd.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from .optimizer import Optimizer | ||
|
||
|
||
class SGD(Optimizer):
    """
    Stochastic Gradient Descent (SGD) optimizer.

    Only the plain (momentum-free) variant is currently usable; a non-zero
    momentum is rejected up front.
    """
    def __init__(self, learning_rate=0.001, momentum=0.0):
        """
        Create an SGD optimizer configuration.

        Args:
            learning_rate (float): Step size used for parameter updates.
            momentum (float): Momentum factor (default: 0.0); any value
                other than 0.0 is rejected for now.

        Raises:
            NotImplementedError: If a non-zero ``momentum`` is requested.
        """
        # The shared learning-rate handling lives in the base class.
        super().__init__(learning_rate)

        # Guard clause: momentum is not implemented yet, so fail fast
        # rather than silently ignoring the setting.
        if momentum != 0.0:
            raise NotImplementedError(
                "Momentum is not supported in the current version of SGD.")
        self.momentum = momentum
        # Placeholder for the per-parameter velocity state that a future
        # momentum implementation would maintain.
        self.velocity = None