Add a centered variance option to the ClippedAdam optimizer #3415

Merged · 7 commits · Jan 25, 2025
examples/lda.py (5 changes: 4 additions & 1 deletion)
@@ -137,7 +137,9 @@ def main(args):
guide = functools.partial(parametrized_guide, predictor)
Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
elbo = Elbo(max_plate_nesting=2)
optim = ClippedAdam({"lr": args.learning_rate})
optim = ClippedAdam(
{"lr": args.learning_rate, "centered_variance": args.centered_variance}
)
svi = SVI(model, guide, optim, elbo)
logging.info("Step\tLoss")
for step in range(args.num_steps):
@@ -160,6 +162,7 @@ def main(args):
parser.add_argument("-n", "--num-steps", default=1000, type=int)
parser.add_argument("-l", "--layer-sizes", default="100-100")
parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
parser.add_argument("-cv", "--centered-variance", default=False, type=bool)
parser.add_argument("-b", "--batch-size", default=32, type=int)
parser.add_argument("--jit", action="store_true")
args = parser.parse_args()
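A quick usage sketch (not part of the diff): with the new argparse option above, the LDA example can be run with centered variance enabled by, for example,

python examples/lda.py -n 1000 -lr 0.01 -cv True

Since argparse's type=bool converts any non-empty string to True, passing any non-empty value enables the option.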
pyro/optim/clipped_adam.py (20 changes: 15 additions & 5 deletions)
@@ -19,14 +19,21 @@ class ClippedAdam(Optimizer):
:param weight_decay: weight decay (L2 penalty) (default: 0)
:param clip_norm: magnitude of norm to which gradients are clipped (default: 10.0)
:param lrd: rate at which learning rate decays (default: 1.0)
:param centered_variance: use centered variance (default: False)

Small modification to the Adam algorithm implemented in torch.optim.Adam
to include gradient clipping and learning rate decay.
to include gradient clipping and learning rate decay, as well as an option to use
the centered variance (see equation 2 in [2]).

Reference
**References**

`A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
https://arxiv.org/abs/1412.6980
[1] `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
https://arxiv.org/abs/1412.6980

[2] `A Two-Step Machine Learning Method for Predicting the Formation Energy of Ternary Compounds`,
Varadarajan Rengaraj, Sebastian Jost, Franz Bethke, Christian Plessl,
Hossein Mirhosseini, Andrea Walther, Thomas D. Kühne
https://doi.org/10.3390/computation11050095
"""

def __init__(
@@ -38,6 +45,7 @@ def __init__(
weight_decay=0,
clip_norm: float = 10.0,
lrd: float = 1.0,
centered_variance: bool = False,
):
defaults = dict(
lr=lr,
@@ -46,6 +54,7 @@ def __init__(
weight_decay=weight_decay,
clip_norm=clip_norm,
lrd=lrd,
centered_variance=centered_variance,
)
super().__init__(params, defaults)

@@ -87,7 +96,8 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]:

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
grad_var = (grad - exp_avg) if group["centered_variance"] else grad
exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)

denom = exp_avg_sq.sqrt().add_(group["eps"])

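For readers skimming the diff, here is a minimal standalone sketch of the second-moment update with and without the centered_variance option. It mirrors the two lines changed in step() above, using toy values for the gradient and the optimizer state; it is illustrative only and not part of the PR.

import torch

# Toy single-parameter state (illustrative values, not from the PR).
beta1, beta2, eps = 0.9, 0.999, 1e-8
grad = torch.tensor([0.5, -1.0])
exp_avg = torch.zeros(2)     # first moment: running mean of the gradient
exp_avg_sq = torch.zeros(2)  # second moment: running (un)centered variance
centered_variance = True

# First-moment update is unchanged from standard Adam.
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

# With centered_variance the running average tracks (grad - exp_avg)**2,
# an estimate of the gradient's variance around its mean, instead of the
# raw second moment grad**2 used by standard Adam.
grad_var = (grad - exp_avg) if centered_variance else grad
exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)

denom = exp_avg_sq.sqrt().add_(eps)
print(exp_avg / denom)  # unscaled Adam step direction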
tests/optim/conftest.py (10 changes: 10 additions & 0 deletions)
@@ -11,3 +11,13 @@ def pytest_collection_modifyitems(items):
item.add_marker(pytest.mark.stage("unit"))
if "init" not in item.keywords:
item.add_marker(pytest.mark.init(rng_seed=123))


def pytest_addoption(parser):
parser.addoption("--plot", action="store", default="FALSE")


def pytest_generate_tests(metafunc):
option_value = metafunc.config.option.plot != "FALSE"
if "plot" in metafunc.fixturenames and option_value is not None:
metafunc.parametrize("plot", [option_value])
tests/optim/test_optim.py (123 changes: 123 additions & 0 deletions)
@@ -435,3 +435,126 @@ def step(svi, optimizer):
actual.append(step(svi, optimizer))

assert_equal(actual, expected)


def test_centered_clipped_adam(plot):
"""
Test the centered variance option of the ClippedAdam optimizer.
In order to create plots run pytest with the plot command line
option set to True, i.e. by executing

'pytest tests/optim/test_optim.py::test_centered_clipped_adam --plot True'

"""
if not plot:
lr_vec = [0.1, 0.001]
else:
lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]

w = torch.Tensor([1, 500])

def loss_fn(p):
return (1 + w * p * p).sqrt().sum() - len(w)

def fit(lr, centered_variance, num_iter=5000):
loss_vec = []
p = torch.nn.Parameter(torch.Tensor([10, 1]))
optim = pyro.optim.clipped_adam.ClippedAdam(
lr=lr, params=[p], centered_variance=centered_variance
)
for count in range(num_iter):
optim.zero_grad()
loss = loss_fn(p)
loss.backward()
optim.step()
loss_vec.append(loss)
return torch.Tensor(loss_vec)

def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
"""
Calculate the number of iterations needed in order to reach the
ultimate loss plus a small threshold, and the convergence rate,
which is the mean per-iteration improvement of the gap between
the loss and the ultimate loss.
"""
ultimate_loss = loss_vec[-tail_len:].mean()
convergence_iter = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
convergence_vec = loss_vec[:convergence_iter] - ultimate_loss
convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
return ultimate_loss, convergence_rate, convergence_iter

def get_convergence_vec(lr_vec, centered_variance):
"""
Fit parameters for a vector of learning rates, with or without centered variance,
and calculate the convergence properties for each learning rate.
"""
ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = [], [], []
for lr in lr_vec:
loss_vec = fit(lr=lr, centered_variance=centered_variance)
ultimate_loss, convergence_rate, convergence_iter = calc_convergence(
loss_vec
)
ultimate_loss_vec.append(ultimate_loss)
convergence_rate_vec.append(convergence_rate)
convergence_iter_vec.append(convergence_iter)
return (
torch.Tensor(ultimate_loss_vec),
torch.Tensor(convergence_rate_vec),
convergence_iter_vec,
)

(
centered_ultimate_loss_vec,
centered_convergence_rate_vec,
centered_convergence_iter_vec,
) = get_convergence_vec(lr_vec=lr_vec, centered_variance=True)
ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = get_convergence_vec(
lr_vec=lr_vec, centered_variance=False
)

# All centered variance results should converge
assert (centered_ultimate_loss_vec < 0.01).all()
# Some uncentered variance results do not converge
assert (ultimate_loss_vec > 0.01).any()
# Verify convergence rate improvement
assert (
(centered_convergence_rate_vec / convergence_rate_vec)
> ((0.12 / torch.Tensor(lr_vec)).log() * 1.08)
).all()

if plot:
from matplotlib import pyplot as plt

plt.figure(figsize=(6, 8))
plt.subplot(3, 1, 1)
plt.loglog(
lr_vec, centered_convergence_iter_vec, "b.-", label="Centered Variance"
)
plt.loglog(lr_vec, convergence_iter_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Convergence Iteration")
plt.title("Convergence Iteration vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.subplot(3, 1, 2)
plt.loglog(
lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
)
plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Convergence Rate")
plt.title("Convergence Rate vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.subplot(3, 1, 3)
plt.semilogx(
lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
)
plt.semilogx(lr_vec, ultimate_loss_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Ultimate Loss")
plt.title("Ultimate Loss vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.tight_layout()
plt.savefig("test_centered_variance.png")