Skip to content

Commit

Permalink
Support tig for isolated nodes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Hale committed Jun 4, 2024
1 parent 0696229 commit 3b330bd
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 10 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="sgw_tools",
version="1.104",
version="1.105",
author="Mark Hale",
license="MIT",
description="Spectral graph wavelet tools",
Expand Down
8 changes: 6 additions & 2 deletions sgw_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,10 +80,14 @@ def _tig(g, s, **kwargs):
if s.ndim == 1:
s = s[..., np.newaxis]
assert s.ndim == 2, "Signal shape was {}".format(s.shape)
if g.G.N == 1:
kwargs['method'] = 'exact'
tig = g.filter(s[..., np.newaxis], **kwargs)
if s.shape[1] == 1: # single signal
if s.shape[1] == 1: # single signal
tig = tig[..., np.newaxis]
if tig.ndim == 2: # single filter
if tig.ndim == 1:
tig = tig[np.newaxis, ..., np.newaxis]
elif tig.ndim == 2: # single filter
tig = tig[..., np.newaxis]
assert tig.shape == s.shape + (tig.shape[2],), "Tig shape was {}".format(tig.shape)
return tig
Expand Down
16 changes: 9 additions & 7 deletions tests/test_biggraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,20 +200,21 @@ def test_estimate_lmin(self):
np.testing.assert_approx_equal(graph.lmin, 0.08818931)

# isolated node
graph = BigGraph([[3]])
np.testing.assert_approx_equal(graph.lmin, np.nan)
for lap_type in ['combinatorial', 'normalized', 'adjacency']:
graph = BigGraph([[3]], lap_type=lap_type)
np.testing.assert_approx_equal(graph.lmin, np.nan)

def test_estimate_lmax(self):
graph = BigGraph.create_from(graphs.Sensor())
self.assertRaises(ValueError, graph.estimate_lmax, method='unk')

def check_lmax(graph, lmax):
graph.estimate_lmax(method='bounds')
np.testing.assert_allclose(graph.lmax, lmax)
np.testing.assert_allclose(graph.lmax, lmax, err_msg='bounds', atol=1e-15)
graph.estimate_lmax(method='lanczos')
np.testing.assert_allclose(graph.lmax, lmax)
np.testing.assert_allclose(graph.lmax, lmax, err_msg='lanczos', atol=1e-15)
graph.compute_fourier_basis()
np.testing.assert_allclose(graph.lmax, lmax)
np.testing.assert_allclose(graph.lmax, lmax, err_msg='fourier', atol=1e-15)

# Full graph (bound is tight).
n_nodes, value = 10, 2
Expand Down Expand Up @@ -242,8 +243,9 @@ def check_lmax(graph, lmax):
check_lmax(graph, lmax=2)

# isolated node
graph = BigGraph([[3]])
check_lmax(graph, lmax=0)
for lap_type in ['combinatorial', 'normalized', 'adjacency']:
graph = BigGraph([[3]], lap_type=lap_type)
check_lmax(graph, lmax=0)

def test_fourier_basis(self):
# Smallest eigenvalue close to zero.
Expand Down
35 changes: 35 additions & 0 deletions tests/test_sgw.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import unittest

import numpy as np
import sgw_tools as sgw
import pygsp as gsp


class TestCase(unittest.TestCase):
def test_tig(self):
G = sgw.BigGraph([[0,1],[1,0]])
g = gsp.filters.Heat(G)
s = sgw.createSignal(G)
sgw._tig(g, s)

# single signal
s = np.array([[1],[-1]])
sgw._tig(g, s)

# multiple signals
s = np.array([[1,2,3],[-1,-2,-3]])
sgw._tig(g, s)

# isolated node
G = sgw.BigGraph([[2]])
g = gsp.filters.Heat(G)
s = sgw.createSignal(G)
sgw._tig(g, s)

# single signal
s = np.array([1])
sgw._tig(g, s)

# multiple signals
s = np.array([[1,2,3]])
sgw._tig(g, s)

0 comments on commit 3b330bd

Please sign in to comment.