From 5c870caa6d28fae4a98ea6964d54471b6cf57727 Mon Sep 17 00:00:00 2001 From: Antony Chan Date: Tue, 7 Jan 2025 09:56:20 -0800 Subject: [PATCH] Compute eps_primal and eps_dual only when needed To reduce the L-ADMM runtime, compute the variables `eps_pri` and `eps_dual` only when needed by convergence check. --- proximal/algorithms/linearized_admm.py | 70 ++++++++++++++++++-------- 1 file changed, 50 insertions(+), 20 deletions(-) diff --git a/proximal/algorithms/linearized_admm.py b/proximal/algorithms/linearized_admm.py index 95524bc..b2a8219 100644 --- a/proximal/algorithms/linearized_admm.py +++ b/proximal/algorithms/linearized_admm.py @@ -41,12 +41,26 @@ def partition(prox_fns, try_diagonalize=True): return psi_fns, omega_fns -def solve(psi_fns, omega_fns, lmb=1.0, mu=None, quad_funcs=None, - max_iters=1000, eps_abs=1e-3, eps_rel=1e-3, - lin_solver="cg", lin_solver_options=None, - implem=None, - try_diagonalize=True, try_fast_norm=True, scaled=False, - metric=None, convlog=None, verbose=0): +def solve( + psi_fns, + omega_fns, + lmb=1.0, + mu=None, + quad_funcs=None, + max_iters=1000, + eps_abs=1e-3, + eps_rel=1e-3, + lin_solver="cg", + lin_solver_options=None, + implem=None, + try_diagonalize=True, + try_fast_norm=True, + scaled=False, + metric=None, + convlog=None, + verbose=0, + conv_check=20, +): # Can only have one omega function. assert len(omega_fns) <= 1 @@ -120,14 +134,6 @@ def solve(psi_fns, omega_fns, lmb=1.0, mu=None, quad_funcs=None, ne.evaluate('u + Kv - z', out=u) K.adjoint(u, KTu) - # Check convergence. - r = ne.evaluate('Kv - z') - ztmp = ne.evaluate('(z - z_prev) / lmb') - K.adjoint(ztmp, s) - eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \ - max([np.linalg.norm(Kv.astype(np.float64)), np.linalg.norm(z.astype(np.float64))]) - eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm(KTu.astype(np.float64)) / (1.0 / lmb) - # Convergence log if convlog is not None: convlog.toc() @@ -135,21 +141,45 @@ def solve(psi_fns, omega_fns, lmb=1.0, mu=None, quad_funcs=None, objval = sum([fn.value for fn in prox_fns]) convlog.record_objective(objval) + should_check_convergence: bool = i % conv_check == 0 + if should_check_convergence: + # Check convergence. + r = ne.evaluate("Kv - z") + ztmp = ne.evaluate("(z - z_prev) / lmb") + K.adjoint(ztmp, s) + + eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * max( + [ + np.linalg.norm(Kv.astype(np.float64)), + np.linalg.norm(z.astype(np.float64)), + ] + ) + eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm( + KTu.astype(np.float64) + ) / (1.0 / lmb) + # Show progess - if verbose > 0 and i % 20 == 0: + if verbose > 0 and should_check_convergence: # Evaluate objective only if required (expensive !) objstr = '' if verbose == 2: K.update_vars(v) - objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns]) + objstr = f", obj_val = {sum([fn.value for fn in prox_fns]):02.03e}" # Evaluate metric potentially - metstr = '' if metric is None else ", {}".format(metric.message(v)) - print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % ( - i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr)) + metstr = "" if metric is None else f", {metric.message(v)}" + print( + f"iter {i:d}: ||r||_2 = {np.linalg.norm(r):.3g}, eps_pri = {eps_pri:.3g}, " + f"||s||_2 = {np.linalg.norm(s):.3f}, eps_dual = {eps_dual:.3f}{objstr:s}{metstr:s}" + ) iter_timing.toc() - if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual: + if ( + i >= 1 + and should_check_convergence + and np.linalg.norm(r) <= eps_pri + and np.linalg.norm(s) <= eps_dual + ): break # Print out timings info.