From 87a24f6684bf6076bd2b7c257297c472e3ae5969 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Tue, 21 Jan 2025 09:29:52 +0200
Subject: [PATCH 1/7] Add option to use centered variance in the ClippedAdam
 optimizer.

---
 pyro/optim/clipped_adam.py | 6 +++++-
 1 file changed, 5 insertions(+), 1 deletion(-)

diff --git a/pyro/optim/clipped_adam.py b/pyro/optim/clipped_adam.py
index 14a6a06656..cd655ef196 100644
--- a/pyro/optim/clipped_adam.py
+++ b/pyro/optim/clipped_adam.py
@@ -19,6 +19,7 @@ class ClippedAdam(Optimizer):
     :param weight_decay: weight decay (L2 penalty) (default: 0)
     :param clip_norm: magnitude of norm to which gradients are clipped (default: 10.0)
     :param lrd: rate at which learning rate decays (default: 1.0)
+    :param centered_variance: use centered variance (default: False)
 
     Small modification to the Adam algorithm implemented in torch.optim.Adam
     to include gradient clipping and learning rate decay.
@@ -38,6 +39,7 @@ def __init__(
         weight_decay=0,
         clip_norm: float = 10.0,
         lrd: float = 1.0,
+        centered_variance: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -46,6 +48,7 @@ def __init__(
             weight_decay=weight_decay,
             clip_norm=clip_norm,
             lrd=lrd,
+            centered_variance=centered_variance,
         )
         super().__init__(params, defaults)
 
@@ -87,7 +90,8 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]:
 
             # Decay the first and second moment running average coefficient
             exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-            exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+            grad_var = (grad - exp_avg) if group["centered_variance"] else grad
+            exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)
 
             denom = exp_avg_sq.sqrt().add_(group["eps"])

From 3d13bda66625daad733e05ef39e6f3ee30ce6bb1 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Tue, 21 Jan 2025 22:49:59 +0200
Subject: [PATCH 2/7] Add test for the centered ClippedAdam optimizer.
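
For context, the change being tested: with centered_variance=True the second
moment tracks the variance of the gradient around its running mean instead of
the raw second moment. A minimal sketch of the two updates on plain tensors
(mirroring the step() change in patch 1, without bias correction, gradient
clipping, or per-parameter state handling; the beta values are the usual Adam
defaults):

    import torch

    beta1, beta2 = 0.9, 0.999
    g = torch.randn(2)    # current gradient
    m = torch.zeros(2)    # first-moment running average (exp_avg)
    v = torch.zeros(2)    # second-moment running average (exp_avg_sq)

    m = beta1 * m + (1 - beta1) * g
    # uncentered (standard Adam): running average of g**2
    v_uncentered = beta2 * v + (1 - beta2) * g * g
    # centered: running average of (g - m)**2, using the already-updated m,
    # exactly as step() does via grad_var = grad - exp_avg
    v_centered = beta2 * v + (1 - beta2) * (g - m) ** 2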
---
 tests/optim/test_optim.py | 83 +++++++++++++++++++++++++++++++++++++++
 1 file changed, 83 insertions(+)

diff --git a/tests/optim/test_optim.py b/tests/optim/test_optim.py
index 6b6dc59d8a..7ca88ebe79 100644
--- a/tests/optim/test_optim.py
+++ b/tests/optim/test_optim.py
@@ -435,3 +435,86 @@ def step(svi, optimizer):
         actual.append(step(svi, optimizer))
 
     assert_equal(actual, expected)
+
+
+def test_centered_clipped_adam(plot_results=False):
+    w = torch.Tensor([1, 500])
+
+    def loss_fn(p):
+        return (1 + w * p * p).sqrt().sum() - len(w)
+
+    def fit(lr, centered_variance, num_iter=5000):
+        loss_vec = []
+        p = torch.nn.Parameter(torch.Tensor([10, 1]))
+        optim = pyro.optim.clipped_adam.ClippedAdam(
+            lr=lr, params=[p], centered_variance=centered_variance
+        )
+        for count in range(num_iter):
+            optim.zero_grad()
+            loss = loss_fn(p)
+            loss.backward()
+            optim.step()
+            loss_vec.append(loss)
+        return torch.Tensor(loss_vec)
+
+    def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
+        ultimate_loss = loss_vec[-tail_len:].mean()
+        idx = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
+        convergence_vec = loss_vec[:idx] - ultimate_loss
+        convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
+        return ultimate_loss, convergence_rate
+
+    def get_convergence_vec(lr_vec, centered_variance):
+        ultimate_loss_vec, convergence_rate_vec = [], []
+        for lr in lr_vec:
+            loss_vec = fit(lr=lr, centered_variance=centered_variance)
+            ultimate_loss, convergence_rate = calc_convergence(loss_vec)
+            ultimate_loss_vec.append(ultimate_loss)
+            convergence_rate_vec.append(convergence_rate)
+            print(lr, centered_variance, ultimate_loss, convergence_rate)
+        return torch.Tensor(ultimate_loss_vec), torch.Tensor(convergence_rate_vec)
+
+    lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
+    centered_ultimate_loss_vec, centered_convergence_rate_vec = get_convergence_vec(
+        lr_vec=lr_vec, centered_variance=True
+    )
+    ultimate_loss_vec, convergence_rate_vec = get_convergence_vec(
+        lr_vec=lr_vec, centered_variance=False
+    )
+
+    # All centered variance results should converge
+    assert (centered_ultimate_loss_vec < 0.01).all()
+    # Some uncentered variance results do not converge
+    assert (ultimate_loss_vec > 0.01).any()
+    # Verify convergence rate improvement
+    assert (
+        (centered_convergence_rate_vec / convergence_rate_vec)
+        > (torch.Tensor([1.2] * len(lr_vec)).cumprod(0))
+    ).all()
+
+    if plot_results:
+        from matplotlib import pyplot as plt
+
+        plt.figure()
+        plt.subplot(2, 1, 1)
+        plt.loglog(
+            lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
+        )
+        plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Convergence Rate")
+        plt.title("Convergence Rate vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.subplot(2, 1, 2)
+        plt.semilogx(
+            lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
+        )
+        plt.semilogx(lr_vec, ultimate_loss_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Ultimate Loss")
+        plt.title("Ultimate Loss vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.tight_layout()
+        plt.savefig("test_centered_variance.png")

From 51e45a66c1b1fa42257f0fa3cad2fa7ba03b76dc Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 24 Jan 2025 10:34:17 +0200
Subject: [PATCH 3/7] Calculate convergence iteration for the centered
 ClippedAdam optimizer.
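
For intuition about the metrics: if the gap between the loss and its ultimate
value decays roughly geometrically, gap_t ~ gap_0 * r**t, then
log(gap_{t-1} / gap_t) = -log(r), so the mean of this log-ratio over the
pre-convergence segment is the per-iteration convergence rate, while the
convergence iteration is the first step that comes within a small threshold
of the ultimate loss. A self-contained sketch of the rate estimate on an
exact geometric decay (illustrative names, not part of the test itself):

    import torch

    ultimate, gap0, r = 2.0, 5.0, 0.9
    t = torch.arange(100, dtype=torch.float32)
    loss_vec = ultimate + gap0 * r**t        # synthetic loss curve
    gap = loss_vec - ultimate
    rate = (gap[:-1] / gap[1:]).log().mean()
    print(rate)                              # recovers -log(r) ~ 0.105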
---
 tests/optim/test_optim.py | 45 ++++++++++++++++++++++++++++-----------
 1 file changed, 32 insertions(+), 13 deletions(-)

diff --git a/tests/optim/test_optim.py b/tests/optim/test_optim.py
index 7ca88ebe79..91430420aa 100644
--- a/tests/optim/test_optim.py
+++ b/tests/optim/test_optim.py
@@ -459,26 +459,35 @@ def fit(lr, centered_variance, num_iter=5000):
 
     def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
         ultimate_loss = loss_vec[-tail_len:].mean()
-        idx = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
-        convergence_vec = loss_vec[:idx] - ultimate_loss
+        convergence_iter = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
+        convergence_vec = loss_vec[:convergence_iter] - ultimate_loss
         convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
-        return ultimate_loss, convergence_rate
+        return ultimate_loss, convergence_rate, convergence_iter
 
     def get_convergence_vec(lr_vec, centered_variance):
-        ultimate_loss_vec, convergence_rate_vec = [], []
+        ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = [], [], []
         for lr in lr_vec:
             loss_vec = fit(lr=lr, centered_variance=centered_variance)
-            ultimate_loss, convergence_rate = calc_convergence(loss_vec)
+            ultimate_loss, convergence_rate, convergence_iter = calc_convergence(
+                loss_vec
+            )
             ultimate_loss_vec.append(ultimate_loss)
             convergence_rate_vec.append(convergence_rate)
+            convergence_iter_vec.append(convergence_iter)
             print(lr, centered_variance, ultimate_loss, convergence_rate)
-        return torch.Tensor(ultimate_loss_vec), torch.Tensor(convergence_rate_vec)
+        return (
+            torch.Tensor(ultimate_loss_vec),
+            torch.Tensor(convergence_rate_vec),
+            convergence_iter_vec,
+        )
 
     lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
-    centered_ultimate_loss_vec, centered_convergence_rate_vec = get_convergence_vec(
-        lr_vec=lr_vec, centered_variance=True
-    )
+    (
+        centered_ultimate_loss_vec,
+        centered_convergence_rate_vec,
+        centered_convergence_iter_vec,
+    ) = get_convergence_vec(lr_vec=lr_vec, centered_variance=True)
     ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = get_convergence_vec(
         lr_vec=lr_vec, centered_variance=False
     )
@@ -495,8 +504,18 @@ def get_convergence_vec(lr_vec, centered_variance):
     if plot_results:
         from matplotlib import pyplot as plt
 
-        plt.figure()
-        plt.subplot(2, 1, 1)
+        plt.figure(figsize=(6, 8))
+        plt.subplot(3, 1, 1)
+        plt.loglog(
+            lr_vec, centered_convergence_iter_vec, "b.-", label="Centered Variance"
+        )
+        plt.loglog(lr_vec, convergence_iter_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Convergence Iteration")
+        plt.title("Convergence Iteration vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.subplot(3, 1, 2)
         plt.loglog(
             lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
         )
@@ -506,7 +525,7 @@
         plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
         plt.xlabel("Learning Rate")
         plt.ylabel("Convergence Rate")
         plt.title("Convergence Rate vs Learning Rate")
         plt.grid()
         plt.legend(loc="best")
-        plt.subplot(2, 1, 2)
+        plt.subplot(3, 1, 3)
         plt.semilogx(
             lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
         )

From 9247336b38df19928a5bfebb1bd036ca87bd45f1 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 24 Jan 2025 11:01:32 +0200
Subject: [PATCH 4/7] Added reference for the centered Adam optimizer.
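
For reference, the centered second-moment recursion added in patch 1, written
out in standard Adam notation (a later patch in this series points the
docstring at equation 2 of the paper cited here; the notation below is not
quoted verbatim from that paper):

    m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
    v_t = \beta_2 v_{t-1} + (1 - \beta_2) (g_t - m_t)^2

Standard Adam instead uses the uncentered update
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2.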
---
 pyro/optim/clipped_adam.py | 10 ++++++++--
 1 file changed, 8 insertions(+), 2 deletions(-)

diff --git a/pyro/optim/clipped_adam.py b/pyro/optim/clipped_adam.py
index cd655ef196..12779d4f8b 100644
--- a/pyro/optim/clipped_adam.py
+++ b/pyro/optim/clipped_adam.py
@@ -22,12 +22,18 @@ class ClippedAdam(Optimizer):
     :param centered_variance: use centered variance (default: False)
 
     Small modification to the Adam algorithm implemented in torch.optim.Adam
-    to include gradient clipping and learning rate decay.
+    to include gradient clipping, learning rate decay, and an option to use
+    the centered variance.
 
-    Reference
+    References
 
     `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
     https://arxiv.org/abs/1412.6980
+
+    `A Two-Step Machine Learning Method for Predicting the Formation Energy of Ternary Compounds`,
+    Varadarajan Rengaraj, Sebastian Jost, Franz Bethke, Christian Plessl,
+    Hossein Mirhosseini, Andrea Walther, Thomas D. Kühne
+    https://doi.org/10.3390/computation11050095
     """
 
     def __init__(

From b51a14a557dd64cfc146f64a629b10d71ee4a504 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 24 Jan 2025 13:37:16 +0200
Subject: [PATCH 5/7] Add option to use the ClippedAdam optimizer with centered
 variance in the Latent Dirichlet Allocation example.

---
 examples/lda.py | 5 ++++-
 1 file changed, 4 insertions(+), 1 deletion(-)

diff --git a/examples/lda.py b/examples/lda.py
index 16fc09ad0b..00d3ac3bef 100644
--- a/examples/lda.py
+++ b/examples/lda.py
@@ -137,7 +137,9 @@ def main(args):
     guide = functools.partial(parametrized_guide, predictor)
     Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
     elbo = Elbo(max_plate_nesting=2)
-    optim = ClippedAdam({"lr": args.learning_rate})
+    optim = ClippedAdam(
+        {"lr": args.learning_rate, "centered_variance": args.centered_variance}
+    )
     svi = SVI(model, guide, optim, elbo)
     logging.info("Step\tLoss")
     for step in range(args.num_steps):
@@ -160,6 +162,7 @@ def main(args):
     parser.add_argument("-n", "--num-steps", default=1000, type=int)
     parser.add_argument("-l", "--layer-sizes", default="100-100")
     parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
+    parser.add_argument("-cv", "--centered-variance", action="store_true")
     parser.add_argument("-b", "--batch-size", default=32, type=int)
     parser.add_argument("--jit", action="store_true")
     args = parser.parse_args()

From efa56d99ab6562a479047bd16b2e45347ec1a97b Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Fri, 24 Jan 2025 23:56:19 +0200
Subject: [PATCH 6/7] Added more detailed comments on ClippedAdam with centered
 variance and its tests.

---
 pyro/optim/clipped_adam.py | 16 ++++++++--------
 tests/optim/test_optim.py  | 10 ++++++++++
 2 files changed, 18 insertions(+), 8 deletions(-)

diff --git a/pyro/optim/clipped_adam.py b/pyro/optim/clipped_adam.py
index 12779d4f8b..14ac268129 100644
--- a/pyro/optim/clipped_adam.py
+++ b/pyro/optim/clipped_adam.py
@@ -23,17 +23,17 @@ class ClippedAdam(Optimizer):
 
     Small modification to the Adam algorithm implemented in torch.optim.Adam
     to include gradient clipping, learning rate decay, and an option to use
-    the centered variance.
+    the centered variance (see equation 2 in [2]).
 
-    References
+    **References**
 
-    `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
-    https://arxiv.org/abs/1412.6980
+    [1] `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
+        https://arxiv.org/abs/1412.6980
 
-    `A Two-Step Machine Learning Method for Predicting the Formation Energy of Ternary Compounds`,
-    Varadarajan Rengaraj, Sebastian Jost, Franz Bethke, Christian Plessl,
-    Hossein Mirhosseini, Andrea Walther, Thomas D. Kühne
-    https://doi.org/10.3390/computation11050095
+    [2] `A Two-Step Machine Learning Method for Predicting the Formation Energy of Ternary Compounds`,
+        Varadarajan Rengaraj, Sebastian Jost, Franz Bethke, Christian Plessl,
+        Hossein Mirhosseini, Andrea Walther, Thomas D. Kühne
+        https://doi.org/10.3390/computation11050095
     """
 
     def __init__(

diff --git a/tests/optim/test_optim.py b/tests/optim/test_optim.py
index 91430420aa..5e745efddf 100644
--- a/tests/optim/test_optim.py
+++ b/tests/optim/test_optim.py
@@ -458,6 +458,12 @@ def fit(lr, centered_variance, num_iter=5000):
         return torch.Tensor(loss_vec)
 
     def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
+        """
+        Calculate the number of iterations needed in order to reach the
+        ultimate loss plus a small threshold, and the convergence rate,
+        which is the mean per-iteration improvement of the gap between
+        the loss and the ultimate loss.
+        """
         ultimate_loss = loss_vec[-tail_len:].mean()
         convergence_iter = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
@@ -465,6 +471,10 @@ def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
         return ultimate_loss, convergence_rate, convergence_iter
 
     def get_convergence_vec(lr_vec, centered_variance):
+        """
+        Fit parameters for a vector of learning rates, with or without
+        centered variance, and calculate the convergence properties for
+        each learning rate.
+        """
         ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = [], [], []
         for lr in lr_vec:
             loss_vec = fit(lr=lr, centered_variance=centered_variance)

From 4d8d1c13a39cb516aa41d4193da5c2fb3475a228 Mon Sep 17 00:00:00 2001
From: Ben Zickel
Date: Sat, 25 Jan 2025 13:29:09 +0200
Subject: [PATCH 7/7] Shortened the ClippedAdam centered variance test and
 added an option to run the full test with plots via a pytest command line
 option.

---
 tests/optim/conftest.py   | 10 ++++++++++
 tests/optim/test_optim.py | 21 ++++++++++++++++-----
 2 files changed, 26 insertions(+), 5 deletions(-)

diff --git a/tests/optim/conftest.py b/tests/optim/conftest.py
index 55dd44d1d5..238deb176e 100644
--- a/tests/optim/conftest.py
+++ b/tests/optim/conftest.py
@@ -11,3 +11,13 @@ def pytest_collection_modifyitems(items):
             item.add_marker(pytest.mark.stage("unit"))
         if "init" not in item.keywords:
             item.add_marker(pytest.mark.init(rng_seed=123))
+
+
+def pytest_addoption(parser):
+    parser.addoption("--plot", action="store", default="FALSE")
+
+
+def pytest_generate_tests(metafunc):
+    option_value = metafunc.config.option.plot != "FALSE"
+    if "plot" in metafunc.fixturenames and option_value is not None:
+        metafunc.parametrize("plot", [option_value])

diff --git a/tests/optim/test_optim.py b/tests/optim/test_optim.py
index 5e745efddf..c0acefd7ef 100644
--- a/tests/optim/test_optim.py
+++ b/tests/optim/test_optim.py
@@ -437,7 +437,20 @@ def step(svi, optimizer):
     assert_equal(actual, expected)
 
 
-def test_centered_clipped_adam(plot_results=False):
+def test_centered_clipped_adam(plot):
+    """
+    Test the centered variance option of the ClippedAdam optimizer.
+    In order to create plots, run pytest with the plot command line
+    option set to True, i.e. by executing
+
+    'pytest tests/optim/test_optim.py::test_centered_clipped_adam --plot True'
+
+    """
+    if not plot:
+        lr_vec = [0.1, 0.001]
+    else:
+        lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
+
     w = torch.Tensor([1, 500])
 
     def loss_fn(p):
@@ -484,14 +497,12 @@ def get_convergence_vec(lr_vec, centered_variance):
             ultimate_loss_vec.append(ultimate_loss)
             convergence_rate_vec.append(convergence_rate)
             convergence_iter_vec.append(convergence_iter)
-            print(lr, centered_variance, ultimate_loss, convergence_rate)
         return (
             torch.Tensor(ultimate_loss_vec),
             torch.Tensor(convergence_rate_vec),
            convergence_iter_vec,
         )
 
-    lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
     (
         centered_ultimate_loss_vec,
         centered_convergence_rate_vec,
         centered_convergence_iter_vec,
     ) = get_convergence_vec(lr_vec=lr_vec, centered_variance=True)
     ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = get_convergence_vec(
         lr_vec=lr_vec, centered_variance=False
     )
@@ -508,10 +519,10 @@ def get_convergence_vec(lr_vec, centered_variance):
     # Verify convergence rate improvement
     assert (
         (centered_convergence_rate_vec / convergence_rate_vec)
-        > (torch.Tensor([1.2] * len(lr_vec)).cumprod(0))
+        > ((0.12 / torch.Tensor(lr_vec)).log() * 1.08)
     ).all()
 
-    if plot_results:
+    if plot:
         from matplotlib import pyplot as plt
 
         plt.figure(figsize=(6, 8))