From 3b330bd7267e3ddf5978f7dc771de185017123ae Mon Sep 17 00:00:00 2001 From: Mark Hale Date: Wed, 5 Jun 2024 00:44:31 +0100 Subject: [PATCH] Support tig for isolated nodes. --- setup.py | 2 +- sgw_tools/__init__.py | 8 ++++++-- tests/test_biggraph.py | 16 +++++++++------- tests/test_sgw.py | 35 +++++++++++++++++++++++++++++++++++ 4 files changed, 51 insertions(+), 10 deletions(-) create mode 100644 tests/test_sgw.py diff --git a/setup.py b/setup.py index 619fe87..cd2774e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="sgw_tools", - version="1.104", + version="1.105", author="Mark Hale", license="MIT", description="Spectral graph wavelet tools", diff --git a/sgw_tools/__init__.py b/sgw_tools/__init__.py index 8d6be7b..fc3b090 100644 --- a/sgw_tools/__init__.py +++ b/sgw_tools/__init__.py @@ -80,10 +80,14 @@ def _tig(g, s, **kwargs): if s.ndim == 1: s = s[..., np.newaxis] assert s.ndim == 2, "Signal shape was {}".format(s.shape) + if g.G.N == 1: + kwargs['method'] = 'exact' tig = g.filter(s[..., np.newaxis], **kwargs) - if s.shape[1] == 1: # single signal + if s.shape[1] == 1: # single signal tig = tig[..., np.newaxis] - if tig.ndim == 2: # single filter + if tig.ndim == 1: + tig = tig[np.newaxis, ..., np.newaxis] + elif tig.ndim == 2: # single filter tig = tig[..., np.newaxis] assert tig.shape == s.shape + (tig.shape[2],), "Tig shape was {}".format(tig.shape) return tig diff --git a/tests/test_biggraph.py b/tests/test_biggraph.py index 5ccccd8..a78938b 100644 --- a/tests/test_biggraph.py +++ b/tests/test_biggraph.py @@ -200,8 +200,9 @@ def test_estimate_lmin(self): np.testing.assert_approx_equal(graph.lmin, 0.08818931) # isolated node - graph = BigGraph([[3]]) - np.testing.assert_approx_equal(graph.lmin, np.nan) + for lap_type in ['combinatorial', 'normalized', 'adjacency']: + graph = BigGraph([[3]], lap_type=lap_type) + np.testing.assert_approx_equal(graph.lmin, np.nan) def test_estimate_lmax(self): graph = BigGraph.create_from(graphs.Sensor()) @@ -209,11 +210,11 @@ def test_estimate_lmax(self): def check_lmax(graph, lmax): graph.estimate_lmax(method='bounds') - np.testing.assert_allclose(graph.lmax, lmax) + np.testing.assert_allclose(graph.lmax, lmax, err_msg='bounds', atol=1e-15) graph.estimate_lmax(method='lanczos') - np.testing.assert_allclose(graph.lmax, lmax) + np.testing.assert_allclose(graph.lmax, lmax, err_msg='lanczos', atol=1e-15) graph.compute_fourier_basis() - np.testing.assert_allclose(graph.lmax, lmax) + np.testing.assert_allclose(graph.lmax, lmax, err_msg='fourier', atol=1e-15) # Full graph (bound is tight). n_nodes, value = 10, 2 @@ -242,8 +243,9 @@ def check_lmax(graph, lmax): check_lmax(graph, lmax=2) # isolated node - graph = BigGraph([[3]]) - check_lmax(graph, lmax=0) + for lap_type in ['combinatorial', 'normalized', 'adjacency']: + graph = BigGraph([[3]], lap_type=lap_type) + check_lmax(graph, lmax=0) def test_fourier_basis(self): # Smallest eigenvalue close to zero. diff --git a/tests/test_sgw.py b/tests/test_sgw.py new file mode 100644 index 0000000..02666ab --- /dev/null +++ b/tests/test_sgw.py @@ -0,0 +1,35 @@ +import unittest + +import numpy as np +import sgw_tools as sgw +import pygsp as gsp + + +class TestCase(unittest.TestCase): + def test_tig(self): + G = sgw.BigGraph([[0,1],[1,0]]) + g = gsp.filters.Heat(G) + s = sgw.createSignal(G) + sgw._tig(g, s) + + # single signal + s = np.array([[1],[-1]]) + sgw._tig(g, s) + + # multiple signals + s = np.array([[1,2,3],[-1,-2,-3]]) + sgw._tig(g, s) + + # isolated node + G = sgw.BigGraph([[2]]) + g = gsp.filters.Heat(G) + s = sgw.createSignal(G) + sgw._tig(g, s) + + # single signal + s = np.array([1]) + sgw._tig(g, s) + + # multiple signals + s = np.array([[1,2,3]]) + sgw._tig(g, s)