Skip to content

Commit

Permalink
Refactoring.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Hale committed Jul 25, 2024
1 parent 799419e commit 2bbd610
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 18 deletions.
29 changes: 15 additions & 14 deletions sgw_tools/approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,28 +110,29 @@ def cheby_op(G, c, signal, domain=None, **kwargs):
return r


def compute_cayley_coeff(f, m, method="complex"):
G = f.G
def fit_chebyshev_coeff(x, y, xmax, m):
p = np.polynomial.Chebyshev.fit(x, y, deg=m, domain=[0, xmax])
p.coef[np.isclose(p.coef, 0)] = 0
return p.coef


def fit_cayley_coeff(x, y, xmax, m, method="complex"):
if method == "complex":
res = optimize.minimize(cayley_loss, x0=np.array([1]), args=(f,m), bounds=((0, None),))
res = optimize.minimize(cayley_loss, x0=np.array([1]), args=(x,y,xmax,m), bounds=((0, None),))
h = res.x
z = util.cayley_transform(h*G.e)
y = f.evaluate(G.e).squeeze()
p = np.polynomial.polynomial.Polynomial.fit(z, y, m, domain=[0, G.lmax], window=[0, G.lmax])
return h[0], np.real(p.coef[0]), p.coef[1:]/2
z = util.cayley_transform(h*x)
p = np.polynomial.polynomial.Polynomial.fit(z, y, deg=m, domain=[0, xmax], window=[0, xmax])
return h[0], np.real(p.coef[0]), util.ctidy(p.coef[1:]/2)
elif method == "real":
y = f.evaluate(G.e).squeeze()
initial_guess = [1] + [1] + [0]*(m-1)
fit = optimize.curve_fit(util.cayley_filter, G.e, y, p0=initial_guess, jac=util.cayley_filter_jac)
fit = optimize.curve_fit(util.cayley_filter, x, y, p0=initial_guess, jac=util.cayley_filter_jac)
coeffs = fit[0]
return coeffs[0], coeffs[1], coeffs[2:]
else:
raise ValueError("Unsupported method")


def cayley_loss(h, f, m):
G = f.G
z = util.cayley_transform(h*G.e)
y = f.evaluate(G.e).squeeze()
p = np.polynomial.polynomial.Polynomial.fit(z, y, m, domain=[0, G.lmax], window=[0, G.lmax])
def cayley_loss(h, x, y, xmax, m):
z = util.cayley_transform(h*x)
p = np.polynomial.Polynomial.fit(z, y, m, domain=[0, xmax], window=[0, xmax])
return np.linalg.norm(y - p(z).real)
8 changes: 8 additions & 0 deletions sgw_tools/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,3 +192,11 @@ def cayley_filter_jac(x, h, c0, *c):
d_c.append(2 * np.real(ct_r))
d_h = 4 * x * np.sum(d_h_cts, axis=0)
return np.array([d_h, d_c0] + d_c).T


def ctidy(z):
no_imag = z[np.isclose(z.imag, 0)]
no_imag = no_imag.real
no_real = z[np.isclose(z.real, 0)]
no_real = no_real.imag * 1j
return z
6 changes: 3 additions & 3 deletions tests/test_sgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,8 +68,8 @@ def test_chebyshev_filter(self):
gsp_actual = gsp_g.filter(signal)
np.testing.assert_allclose(gsp_actual, expected, err_msg="pygsp coeffs")

np_cheby = np.polynomial.Chebyshev.fit(G.e, func(G.e), deg=order, domain=domain)
np_g = sgw.ChebyshevFilter(G, np_cheby.coef, domain, "numpy")
np_coeffs = sgw.approximations.fit_chebyshev_coeff(G.e, func(G.e), domain[1], order)
np_g = sgw.ChebyshevFilter(G, np_coeffs, domain, "numpy")
np.testing.assert_allclose(np_g.evaluate(G.e).squeeze(), func_e, err_msg="numpy evaluate")
np_actual = np_g.filter(signal)
np.testing.assert_allclose(np_actual, expected, err_msg="numpy coeffs")
Expand All @@ -83,7 +83,7 @@ def test_cayley_filter(self):
expected = g.filter(signal, method="exact")

order = 20
h, c0, c = sgw.approximations.compute_cayley_coeff(g, order, method="real")
h, c0, c = sgw.approximations.fit_cayley_coeff(G.e, func(G.e), G.lmax, order, method="real")

g = sgw.CayleyFilter(G, np.array([h, c0] + list(c)))
np.testing.assert_allclose(g.evaluate(G.e).squeeze(), func(G.e))
Expand Down
3 changes: 2 additions & 1 deletion tests/test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ def test_ChebLayer(self):
def test_CayleyLayer(self):
G = gsp.graphs.Sensor(34, seed=506, lap_type="normalized")
K = 8
h, c0, c = sgw.approximations.compute_cayley_coeff(gsp.filters.Heat(G, tau=5), m=K)
func = lambda x: np.exp(-5*x/G.lmax)
h, c0, c = sgw.approximations.fit_cayley_coeff(G.e, func(G.e), G.lmax, m=K)

s = sgw.createSignal(G, nodes=[0])
g = sgw.CayleyFilter(G, np.array([h, c0] + list(c)))
Expand Down
5 changes: 5 additions & 0 deletions tests/test_util.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import unittest

import numpy as np
from scipy import sparse
from sgw_tools import util

Expand All @@ -18,3 +19,7 @@ def test_count_negatives(self):
def test_operator_norm(self):
W = sparse.csr_matrix([[4.1]])
self.assertEqual(util.operator_norm(W), 4.1)

def test_ctidy(self):
z = np.array([5 + 1e-16j, 1e-16 + 3j])
np.testing.assert_allclose([5, 3j], util.ctidy(z))

0 comments on commit 2bbd610

Please sign in to comment.