
Commit

Remove step function
ragmani committed Jan 15, 2025
1 parent 77b096c commit 69fef07
Showing 3 changed files with 0 additions and 49 deletions.
@@ -18,28 +18,3 @@ def __init__(self, learning_rate=0.001, beta1=0.9, beta2=0.999, epsilon=1e-7):
         self.beta1 = beta1
         self.beta2 = beta2
         self.epsilon = epsilon
-        self.m = None
-        self.v = None
-        self.t = 0
-
-    def step(self, gradients, parameters):
-        """
-        Update parameters using Adam optimization.
-        Args:
-            gradients (list): List of gradients for each parameter.
-            parameters (list): List of parameters to be updated.
-        """
-        if self.m is None:
-            self.m = [0] * len(parameters)
-        if self.v is None:
-            self.v = [0] * len(parameters)
-
-        self.t += 1
-        for i, (grad, param) in enumerate(zip(gradients, parameters)):
-            self.m[i] = self.beta1 * self.m[i] + (1 - self.beta1) * grad
-            self.v[i] = self.beta2 * self.v[i] + (1 - self.beta2) * (grad**2)
-
-            m_hat = self.m[i] / (1 - self.beta1**self.t)
-            v_hat = self.v[i] / (1 - self.beta2**self.t)
-
-            param -= self.learning_rate * m_hat / (v_hat**0.5 + self.epsilon)
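
The removed Adam `step` above is the standard bias-corrected Adam update: exponential moving averages of the gradient and squared gradient, divided by their bias-correction terms. As a reference only, here is a minimal, self-contained sketch of the same update outside the class, assuming the parameters are NumPy arrays so that the in-place `-=` actually mutates the stored tensors; the function name `adam_step` and the sample values are illustrative and not part of this repository.

import numpy as np

def adam_step(params, grads, state, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-7):
    # state carries the per-parameter moments m, v and the shared timestep t.
    if not state:
        state["m"] = [np.zeros_like(p) for p in params]
        state["v"] = [np.zeros_like(p) for p in params]
        state["t"] = 0
    state["t"] += 1
    t = state["t"]
    for i, (g, p) in enumerate(zip(grads, params)):
        state["m"][i] = beta1 * state["m"][i] + (1 - beta1) * g
        state["v"][i] = beta2 * state["v"][i] + (1 - beta2) * g**2
        m_hat = state["m"][i] / (1 - beta1**t)  # bias correction
        v_hat = state["v"][i] / (1 - beta2**t)
        p -= lr * m_hat / (np.sqrt(v_hat) + eps)  # mutates params[i] in place

# Illustrative usage: one update on a single two-element parameter.
params = [np.array([0.5, -0.3])]
grads = [np.array([0.1, 0.2])]
state = {}
adam_step(params, grads, state)

Note that the class version's final `param -= ...` only updates the stored parameter if parameters are mutable array objects; with plain Python floats that line rebinds the loop variable and leaves `parameters[i]` unchanged, which is why the SGD variant below writes to `parameters[i]` directly.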
@@ -13,12 +13,3 @@ def __init__(self, learning_rate=0.001, nums_trainable_ops=trainable_ops.ALL):
"""
self.learning_rate = learning_rate
self.nums_trainable_ops = nums_trainable_ops

def step(self, gradients, parameters):
"""
Update parameters based on gradients. Should be implemented by subclasses.
Args:
gradients (list): List of gradients for each parameter.
parameters (list): List of parameters to be updated.
"""
raise NotImplementedError("Subclasses must implement the `step` method.")
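
The removed base-class `step` was a stub that raised NotImplementedError to force concrete optimizers to override it. A roughly equivalent way to express the same contract, shown only as a hedged sketch, is an abstract method via `abc`; the class name `OptimizerBase` is illustrative and not taken from this repository.

from abc import ABC, abstractmethod

class OptimizerBase(ABC):
    # Illustrative stand-in for the optimizer base class touched by this commit.
    def __init__(self, learning_rate=0.001):
        self.learning_rate = learning_rate

    @abstractmethod
    def step(self, gradients, parameters):
        """Update parameters from gradients; concrete optimizers must override this."""

With `abc`, a subclass that forgets to implement `step` fails at construction time rather than on the first call.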
@@ -19,18 +19,3 @@ def __init__(self, learning_rate=0.001, momentum=0.0):
"Momentum is not supported in the current version of SGD.")
self.momentum = momentum
self.velocity = None

def step(self, gradients, parameters):
"""
Update parameters using SGD with optional momentum.
Args:
gradients (list): List of gradients for each parameter.
parameters (list): List of parameters to be updated.
"""
if self.velocity is None:
self.velocity = [0] * len(parameters)

for i, (grad, param) in enumerate(zip(gradients, parameters)):
self.velocity[
i] = self.momentum * self.velocity[i] - self.learning_rate * grad
parameters[i] += self.velocity[i]
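
The removed SGD `step` is the usual momentum update, velocity = momentum * velocity - lr * grad followed by parameter += velocity; since the constructor above raises for unsupported momentum values, this reduces to vanilla SGD, parameter -= lr * grad, in practice. A minimal sketch of the same update, again assuming NumPy array parameters; `sgd_step` and the sample values are illustrative only.

import numpy as np

def sgd_step(params, grads, velocity, lr=0.001, momentum=0.0):
    # velocity holds one buffer per parameter, pre-initialized to zeros.
    for i, (g, p) in enumerate(zip(grads, params)):
        velocity[i] = momentum * velocity[i] - lr * g
        p += velocity[i]  # mutates params[i] in place

# Illustrative usage with momentum disabled, matching the removed class's constraint.
params = [np.array([1.0, 2.0])]
grads = [np.array([0.5, -0.5])]
velocity = [np.zeros_like(p) for p in params]
sgd_step(params, grads, velocity)  # params[0] is now [0.9995, 2.0005]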
