Add a centered variance option to the ClippedAdam optimizer #3415

Merged (7 commits), Jan 25, 2025
pyro/optim/clipped_adam.py (5 additions & 1 deletion)
@@ -19,6 +19,7 @@ class ClippedAdam(Optimizer):
:param weight_decay: weight decay (L2 penalty) (default: 0)
:param clip_norm: magnitude of norm to which gradients are clipped (default: 10.0)
:param lrd: rate at which learning rate decays (default: 1.0)
:param centered_variance: use centered variance (default: False)

Small modification to the Adam algorithm implemented in torch.optim.Adam
to include gradient clipping and learning rate decay.
@@ -38,6 +39,7 @@ def __init__(
weight_decay=0,
clip_norm: float = 10.0,
lrd: float = 1.0,
centered_variance: bool = False,
):
defaults = dict(
lr=lr,
@@ -46,6 +48,7 @@
weight_decay=weight_decay,
clip_norm=clip_norm,
lrd=lrd,
centered_variance=centered_variance,
)
super().__init__(params, defaults)

@@ -87,7 +90,8 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]:

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
grad_var = (grad - exp_avg) if group["centered_variance"] else grad
exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)

denom = exp_avg_sq.sqrt().add_(group["eps"])

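For context (this note is not part of the diff): with `centered_variance=True` the second-moment accumulator tracks the variance of the gradient around its running mean, E[(g - m)^2], rather than the raw second moment E[g^2]. A minimal standalone sketch of the two recurrences, using the same in-place tensor ops as the change above (`moment_update` is a hypothetical helper, not part of the PR):

```python
import torch


def moment_update(exp_avg, exp_avg_sq, grad, beta1, beta2, centered_variance):
    # First moment: m_t = beta1 * m_{t-1} + (1 - beta1) * g_t
    exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    # Uncentered: v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2
    # Centered:   v_t = beta2 * v_{t-1} + (1 - beta2) * (g_t - m_t)^2
    grad_var = (grad - exp_avg) if centered_variance else grad
    exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)
    return exp_avg, exp_avg_sq


m, v = torch.zeros(2), torch.zeros(2)
g = torch.tensor([0.5, -1.0])
moment_update(m, v, g, beta1=0.9, beta2=0.999, centered_variance=True)
```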
tests/optim/test_optim.py (83 additions & 0 deletions)
@@ -435,3 +435,86 @@ def step(svi, optimizer):
actual.append(step(svi, optimizer))

assert_equal(actual, expected)


def test_centered_clipped_adam(plot_results=False):
w = torch.Tensor([1, 500])

def loss_fn(p):
return (1 + w * p * p).sqrt().sum() - len(w)

def fit(lr, centered_variance, num_iter=5000):
loss_vec = []
p = torch.nn.Parameter(torch.Tensor([10, 1]))
optim = pyro.optim.clipped_adam.ClippedAdam(
lr=lr, params=[p], centered_variance=centered_variance
)
for count in range(num_iter):
optim.zero_grad()
loss = loss_fn(p)
loss.backward()
optim.step()
loss_vec.append(loss)
return torch.Tensor(loss_vec)

def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
ultimate_loss = loss_vec[-tail_len:].mean()
idx = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
convergence_vec = loss_vec[:idx] - ultimate_loss
convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
return ultimate_loss, convergence_rate

def get_convergence_vec(lr_vec, centered_variance):
ultimate_loss_vec, convergence_rate_vec = [], []
for lr in lr_vec:
loss_vec = fit(lr=lr, centered_variance=centered_variance)
ultimate_loss, convergence_rate = calc_convergence(loss_vec)
ultimate_loss_vec.append(ultimate_loss)
convergence_rate_vec.append(convergence_rate)
print(lr, centered_variance, ultimate_loss, convergence_rate)
return torch.Tensor(ultimate_loss_vec), torch.Tensor(convergence_rate_vec)

lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
centered_ultimate_loss_vec, centered_convergence_rate_vec = get_convergence_vec(
lr_vec=lr_vec, centered_variance=True
)
ultimate_loss_vec, convergence_rate_vec = get_convergence_vec(
lr_vec=lr_vec, centered_variance=False
)

# All centered variance results should converge
assert (centered_ultimate_loss_vec < 0.01).all()
# Some uncentered variance results do not converge
assert (ultimate_loss_vec > 0.01).any()
# Verify convergence rate improvement
assert (
(centered_convergence_rate_vec / convergence_rate_vec)
> (torch.Tensor([1.2] * len(lr_vec)).cumprod(0))
).all()

if plot_results:
from matplotlib import pyplot as plt

plt.figure()
plt.subplot(2, 1, 1)
plt.loglog(
lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
)
plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Convergence Rate")
plt.title("Convergence Rate vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.subplot(2, 1, 2)
plt.semilogx(
lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
)
plt.semilogx(lr_vec, ultimate_loss_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Ultimate Loss")
plt.title("Ultimate Loss vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.tight_layout()
plt.savefig("test_centered_variance.png")
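
The test measures a per-step convergence rate as the mean of log((loss_t - loss_final) / (loss_{t+1} - loss_final)) over the pre-convergence iterations, and asserts that the centered variant reaches the optimum at every learning rate and converges faster than the uncentered one. For completeness, a hedged usage sketch of the new option (it assumes the keyword is forwarded unchanged through Pyro's dict-based optimizer wrapper, like the existing arguments):

```python
import torch
import pyro.optim

# Direct use of the optimizer class, as in the test above:
params = [torch.nn.Parameter(torch.zeros(2))]
optim = pyro.optim.clipped_adam.ClippedAdam(
    params=params, lr=0.01, centered_variance=True
)

# Through the SVI-facing wrapper, passing optimizer arguments as a dict:
svi_optim = pyro.optim.ClippedAdam(
    {"lr": 0.01, "clip_norm": 10.0, "centered_variance": True}
)
```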