-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
91 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
#!/usr/bin/env python | ||
# -*- coding: utf-8 -*- | ||
# This file is part of the SCICO package. Details of the copyright | ||
# and user license can be found in the 'LICENSE.txt' file distributed | ||
# with the package. | ||
|
||
r""" | ||
SCICO Call Tracing | ||
================== | ||
This example demonstates the call tracing functionality provided by the | ||
[trace](../_autosummary/scico.trace.rst) module. It is based on the | ||
[non-negative BPDN example](sparsecode_nn_admm.rst). | ||
""" | ||
|
||
import numpy as np | ||
|
||
import jax | ||
|
||
import scico.numpy as snp | ||
from scico import functional, linop, loss | ||
from scico.optimize.admm import ADMM, MatrixSubproblemSolver | ||
from scico.trace import register_variable, trace_scico_calls | ||
from scico.util import device_info | ||
|
||
""" | ||
Initialize tracing. JIT must be disabled for correct tracing. | ||
""" | ||
jax.config.update("jax_disable_jit", True) | ||
trace_scico_calls() | ||
|
||
|
||
""" | ||
Create random dictionary, reference random sparse representation, and | ||
test signal consisting of the synthesis of the reference sparse | ||
representation. | ||
""" | ||
m = 32 # signal size | ||
n = 128 # dictionary size | ||
s = 10 # sparsity level | ||
|
||
np.random.seed(1) | ||
D = np.random.randn(m, n).astype(np.float32) | ||
D = D / np.linalg.norm(D, axis=0, keepdims=True) # normalize dictionary | ||
|
||
xt = np.zeros(n, dtype=np.float32) # true signal | ||
idx = np.random.randint(low=0, high=n, size=s) # support of xt | ||
xt[idx] = np.random.rand(s) | ||
y = D @ xt + 5e-2 * np.random.randn(m) # synthetic signal | ||
|
||
xt = snp.array(xt) # convert to jax array | ||
y = snp.array(y) # convert to jax array | ||
|
||
|
||
""" | ||
Set up the forward operator and ADMM solver object. | ||
""" | ||
lmbda = 1e-1 | ||
A = linop.MatrixOperator(D) | ||
f = loss.SquaredL2Loss(y=y, A=A) | ||
g_list = [lmbda * functional.L1Norm(), functional.NonNegativeIndicator()] | ||
C_list = [linop.Identity((n)), linop.Identity((n))] | ||
rho_list = [1.0, 1.0] | ||
maxiter = 10 # number of ADMM iterations | ||
|
||
solver = ADMM( | ||
f=f, | ||
g_list=g_list, | ||
C_list=C_list, | ||
rho_list=rho_list, | ||
x0=A.adj(y), | ||
maxiter=maxiter, | ||
subproblem_solver=MatrixSubproblemSolver(), | ||
itstat_options={"display": True, "period": 5}, | ||
) | ||
|
||
|
||
register_variable(A, "A") | ||
register_variable(f, "f") | ||
register_variable(g_list[0], "g_list[0]") | ||
register_variable(g_list[1], "g_list[1]") | ||
register_variable(C_list[0], "C_list[0]") | ||
register_variable(C_list[1], "C_list[1]") | ||
register_variable(solver, "solver") | ||
|
||
|
||
""" | ||
Run the solver. | ||
""" | ||
print(f"Solving on {device_info()}\n") | ||
x = solver.solve() |