Skip to content

Commit

Permalink
Compute eps_primal and eps_dual only when needed
Browse files Browse the repository at this point in the history
To reduce the L-ADMM runtime, compute the variables `eps_pri` and
`eps_dual` only when they are needed by the convergence check.
  • Loading branch information
antonysigma committed Jan 7, 2025
1 parent c53a7c4 commit 5c870ca
Showing 1 changed file with 50 additions and 20 deletions.
70 changes: 50 additions & 20 deletions proximal/algorithms/linearized_admm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,26 @@ def partition(prox_fns, try_diagonalize=True):
return psi_fns, omega_fns


def solve(psi_fns, omega_fns, lmb=1.0, mu=None, quad_funcs=None,
max_iters=1000, eps_abs=1e-3, eps_rel=1e-3,
lin_solver="cg", lin_solver_options=None,
implem=None,
try_diagonalize=True, try_fast_norm=True, scaled=False,
metric=None, convlog=None, verbose=0):
def solve(
psi_fns,
omega_fns,
lmb=1.0,
mu=None,
quad_funcs=None,
max_iters=1000,
eps_abs=1e-3,
eps_rel=1e-3,
lin_solver="cg",
lin_solver_options=None,
implem=None,
try_diagonalize=True,
try_fast_norm=True,
scaled=False,
metric=None,
convlog=None,
verbose=0,
conv_check=20,
):

# Can only have one omega function.
assert len(omega_fns) <= 1
Expand Down Expand Up @@ -120,36 +134,52 @@ def solve(psi_fns, omega_fns, lmb=1.0, mu=None, quad_funcs=None,
ne.evaluate('u + Kv - z', out=u)
K.adjoint(u, KTu)

# Check convergence.
r = ne.evaluate('Kv - z')
ztmp = ne.evaluate('(z - z_prev) / lmb')
K.adjoint(ztmp, s)
eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * \
max([np.linalg.norm(Kv.astype(np.float64)), np.linalg.norm(z.astype(np.float64))])
eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm(KTu.astype(np.float64)) / (1.0 / lmb)

# Convergence log
if convlog is not None:
convlog.toc()
K.update_vars(v)
objval = sum([fn.value for fn in prox_fns])
convlog.record_objective(objval)

should_check_convergence: bool = i % conv_check == 0
if should_check_convergence:
# Check convergence.
r = ne.evaluate("Kv - z")
ztmp = ne.evaluate("(z - z_prev) / lmb")
K.adjoint(ztmp, s)

eps_pri = np.sqrt(K.output_size) * eps_abs + eps_rel * max(
[
np.linalg.norm(Kv.astype(np.float64)),
np.linalg.norm(z.astype(np.float64)),
]
)
eps_dual = np.sqrt(K.input_size) * eps_abs + eps_rel * np.linalg.norm(
KTu.astype(np.float64)
) / (1.0 / lmb)

# Show progress
if verbose > 0 and i % 20 == 0:
if verbose > 0 and should_check_convergence:
# Evaluate objective only if required (expensive !)
objstr = ''
if verbose == 2:
K.update_vars(v)
objstr = ", obj_val = %02.03e" % sum([fn.value for fn in prox_fns])
objstr = f", obj_val = {sum([fn.value for fn in prox_fns]):02.03e}"

# Evaluate metric potentially
metstr = '' if metric is None else ", {}".format(metric.message(v))
print("iter %d: ||r||_2 = %.3f, eps_pri = %.3f, ||s||_2 = %.3f, eps_dual = %.3f%s%s" % (
i, np.linalg.norm(r), eps_pri, np.linalg.norm(s), eps_dual, objstr, metstr))
metstr = "" if metric is None else f", {metric.message(v)}"
print(
f"iter {i:d}: ||r||_2 = {np.linalg.norm(r):.3g}, eps_pri = {eps_pri:.3g}, "
f"||s||_2 = {np.linalg.norm(s):.3f}, eps_dual = {eps_dual:.3f}{objstr:s}{metstr:s}"
)

iter_timing.toc()
if np.linalg.norm(r) <= eps_pri and np.linalg.norm(s) <= eps_dual:
if (
i >= 1
and should_check_convergence
and np.linalg.norm(r) <= eps_pri
and np.linalg.norm(s) <= eps_dual
):
break

# Print out timings info.
Expand Down

0 comments on commit 5c870ca

Please sign in to comment.