Skip to content

Commit

Permalink
Add optimised Chebyshev filter.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Hale committed Jun 25, 2024
1 parent 0875af1 commit f1d3f54
Show file tree
Hide file tree
Showing 4 changed files with 106 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="sgw_tools",
version="2.2",
version="2.3",
author="Mark Hale",
license="MIT",
description="Spectral graph wavelet tools",
Expand Down
2 changes: 2 additions & 0 deletions sgw_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,10 @@
from .filters import GWHeat
from .filters import GaussianFilter
from .filters import ShiftedFilter
from .filters import ChebyshevFilter
from .filters import CustomFilter


"""
Refs: Shape Classification using Spectral Graph Wavelets/Spectral Geometric Methods for Deformable 3D Shape Retrieval
Expand Down
83 changes: 83 additions & 0 deletions sgw_tools/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,89 @@ def __init__(self, g, Nf=100, shifts=None):
super().__init__(G, kernels)


class ChebyshevFilter(gsp.filters.Filter):
def __init__(self, G, coeff_bank, domain, coeff_normalization="pygsp"):
coeff_bank = np.array(coeff_bank)
if coeff_bank.ndim == 1:
coeff_bank = coeff_bank.reshape(1, -1)
self.coeff_bank = coeff_bank

if coeff_normalization == "numpy":
self.coeff_bank[:, 0] *= 2

kernels = [
np.polynomial.Chebyshev(coeffs, domain=domain) for coeffs in coeff_bank
]
super().__init__(G, kernels)

def filter(self, s, method='chebyshev', order=30):
if s.shape[0] != self.G.N:
raise ValueError('First dimension should be the number of nodes '
'G.N = {}, got {}.'.format(self.G.N, s.shape))

# TODO: not in self.Nin (Nf = Nin x Nout).
if s.ndim == 1 or s.shape[-1] not in [1, self.Nf]:
if s.ndim == 3:
raise ValueError('Third dimension (#features) should be '
'either 1 or the number of filters Nf = {}, '
'got {}.'.format(self.Nf, s.shape))
s = np.expand_dims(s, -1)
n_features_in = s.shape[-1]

if s.ndim < 3:
s = np.expand_dims(s, 1)
n_signals = s.shape[1]

if s.ndim > 3:
raise ValueError('At most 3 dimensions: '
'#nodes x #signals x #features.')
assert s.ndim == 3

# TODO: generalize to 2D (m --> n) filter banks.
# Only 1 --> Nf (analysis) and Nf --> 1 (synthesis) for now.
n_features_out = self.Nf if n_features_in == 1 else 1

if method == 'exact':

# TODO: will be handled by g.adjoint().
axis = 1 if n_features_in == 1 else 2
f = self.evaluate(self.G.e)
f = np.expand_dims(f.T, axis)
assert f.shape == (self.G.N, n_features_in, n_features_out)

s = self.G.gft(s)
s = np.matmul(s, f)
s = self.G.igft(s)

elif method == 'chebyshev':

c = self.coeff_bank

if n_features_in == 1: # Analysis.
s = s.squeeze(axis=2)
s = gsp.filters.approximations.cheby_op(self.G, c, s)
s = s.reshape((self.G.N, n_features_out, n_signals), order='F')
s = s.swapaxes(1, 2)

elif n_features_in == self.Nf: # Synthesis.
s = s.swapaxes(1, 2)
s_in = s.reshape(
(self.G.N * n_features_in, n_signals), order='F')
s = np.zeros((self.G.N, n_signals))
tmpN = np.arange(self.G.N, dtype=int)
for i in range(n_features_in):
s += gsp.filters.approximations.cheby_op(self.G,
c[i],
s_in[i * self.G.N + tmpN])
s = np.expand_dims(s, 2)

else:
raise ValueError('Unknown method {}.'.format(method))

# Return a 1D signal if e.g. a 1D signal was filtered by one filter.
return s.squeeze()


class CustomFilter(gsp.filters.Filter):
def __init__(self, G, funcs, scales=1):
if not hasattr(funcs, '__iter__'):
Expand Down
20 changes: 20 additions & 0 deletions tests/test_sgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,23 @@ def test_tig(self):
g = gsp.filters.Heat(G, tau=[1,5])
s = np.array([[1,2,3]])
sgw._tig(g, s)

def test_chebyshev_filter(self):
G = gsp.graphs.Sensor(100, lap_type="normalized", seed=5)
func = lambda x: np.exp(-x**2)
signal = np.ones(G.N)
g = sgw.CustomFilter(G, func)
order = 20
expected = g.filter(signal, order=order)

gsp_coeffs = gsp.filters.compute_cheby_coeff(g, m=order)
gsp_g = sgw.ChebyshevFilter(G, gsp_coeffs, [0, G.lmax], "pygsp")
gsp_actual = gsp_g.filter(signal, order=order)
np.testing.assert_allclose(gsp_actual, expected, err_msg="pygsp coeffs")

domain = [0, 2]
np_cheby = np.polynomial.Chebyshev.fit(G.e, func(G.e), deg=order, domain=domain)
np_g = sgw.ChebyshevFilter(G, np_cheby.coef, domain, "numpy")
np_actual = np_g.filter(signal, order=order)
np.testing.assert_allclose(np_actual, expected, err_msg="numpy coeffs", rtol=0.1)

0 comments on commit f1d3f54

Please sign in to comment.