diff --git a/sgw_tools/__init__.py b/sgw_tools/__init__.py index 1e8f47c..bd06ac9 100644 --- a/sgw_tools/__init__.py +++ b/sgw_tools/__init__.py @@ -28,8 +28,10 @@ from .filters import GWHeat from .filters import GaussianFilter from .filters import ShiftedFilter +from .filters import ChebyshevFilter from .filters import CustomFilter + """ Refs: Shape Classification using Spectral Graph Wavelets/Spectral Geometric Methods for Deformable 3D Shape Retrieval diff --git a/sgw_tools/filters.py b/sgw_tools/filters.py index 151939a..9dae312 100644 --- a/sgw_tools/filters.py +++ b/sgw_tools/filters.py @@ -72,6 +72,89 @@ def __init__(self, g, Nf=100, shifts=None): super().__init__(G, kernels) +class ChebyshevFilter(gsp.filters.Filter): + def __init__(self, G, coeff_bank, domain, coeff_normalization="pygsp"): + coeff_bank = np.array(coeff_bank) + if coeff_bank.ndim == 1: + coeff_bank = coeff_bank.reshape(1, -1) + self.coeff_bank = coeff_bank + + if coeff_normalization == "numpy": + self.coeff_bank[:, 0] *= 2 + + kernels = [ + np.polynomial.Chebyshev(coeffs, domain=domain) for coeffs in coeff_bank + ] + super().__init__(G, kernels) + + def filter(self, s, method='chebyshev', order=30): + if s.shape[0] != self.G.N: + raise ValueError('First dimension should be the number of nodes ' + 'G.N = {}, got {}.'.format(self.G.N, s.shape)) + + # TODO: not in self.Nin (Nf = Nin x Nout). + if s.ndim == 1 or s.shape[-1] not in [1, self.Nf]: + if s.ndim == 3: + raise ValueError('Third dimension (#features) should be ' + 'either 1 or the number of filters Nf = {}, ' + 'got {}.'.format(self.Nf, s.shape)) + s = np.expand_dims(s, -1) + n_features_in = s.shape[-1] + + if s.ndim < 3: + s = np.expand_dims(s, 1) + n_signals = s.shape[1] + + if s.ndim > 3: + raise ValueError('At most 3 dimensions: ' + '#nodes x #signals x #features.') + assert s.ndim == 3 + + # TODO: generalize to 2D (m --> n) filter banks. + # Only 1 --> Nf (analysis) and Nf --> 1 (synthesis) for now. + n_features_out = self.Nf if n_features_in == 1 else 1 + + if method == 'exact': + + # TODO: will be handled by g.adjoint(). + axis = 1 if n_features_in == 1 else 2 + f = self.evaluate(self.G.e) + f = np.expand_dims(f.T, axis) + assert f.shape == (self.G.N, n_features_in, n_features_out) + + s = self.G.gft(s) + s = np.matmul(s, f) + s = self.G.igft(s) + + elif method == 'chebyshev': + + c = self.coeff_bank + + if n_features_in == 1: # Analysis. + s = s.squeeze(axis=2) + s = gsp.filters.approximations.cheby_op(self.G, c, s) + s = s.reshape((self.G.N, n_features_out, n_signals), order='F') + s = s.swapaxes(1, 2) + + elif n_features_in == self.Nf: # Synthesis. + s = s.swapaxes(1, 2) + s_in = s.reshape( + (self.G.N * n_features_in, n_signals), order='F') + s = np.zeros((self.G.N, n_signals)) + tmpN = np.arange(self.G.N, dtype=int) + for i in range(n_features_in): + s += gsp.filters.approximations.cheby_op(self.G, + c[i], + s_in[i * self.G.N + tmpN]) + s = np.expand_dims(s, 2) + + else: + raise ValueError('Unknown method {}.'.format(method)) + + # Return a 1D signal if e.g. a 1D signal was filtered by one filter. + return s.squeeze() + + class CustomFilter(gsp.filters.Filter): def __init__(self, G, funcs, scales=1): if not hasattr(funcs, '__iter__'): diff --git a/tests/test_sgw.py b/tests/test_sgw.py index ae30d0a..b3ddb2b 100644 --- a/tests/test_sgw.py +++ b/tests/test_sgw.py @@ -48,3 +48,23 @@ def test_tig(self): g = gsp.filters.Heat(G, tau=[1,5]) s = np.array([[1,2,3]]) sgw._tig(g, s) + + def test_chebyshev_filter(self): + G = gsp.graphs.Sensor(100, lap_type="normalized", seed=5) + func = lambda x: np.exp(-x**2) + signal = np.ones(G.N) + g = sgw.CustomFilter(G, func) + order = 20 + expected = g.filter(signal, order=order) + + gsp_coeffs = gsp.filters.compute_cheby_coeff(g, m=order) + gsp_g = sgw.ChebyshevFilter(G, gsp_coeffs, [0, G.lmax], "pygsp") + gsp_actual = gsp_g.filter(signal, order=order) + np.testing.assert_allclose(gsp_actual, expected, err_msg="pygsp coeffs") + + domain = [0, 2] + np_cheby = np.polynomial.Chebyshev.fit(G.e, func(G.e), deg=order, domain=domain) + np_g = sgw.ChebyshevFilter(G, np_cheby.coef, domain, "numpy") + np_actual = np_g.filter(signal, order=order) + np.testing.assert_allclose(np_actual, expected, err_msg="numpy coeffs", rtol=0.1) +