Skip to content

Commit

Permalink
Fix handling of tig shape.
Browse files Browse the repository at this point in the history
  • Loading branch information
Mark Hale committed Jun 5, 2024
1 parent 3b330bd commit e63c35d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="sgw_tools",
version="1.105",
version="2.0",
author="Mark Hale",
license="MIT",
description="Spectral graph wavelet tools",
Expand Down
27 changes: 20 additions & 7 deletions sgw_tools/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,31 @@ def _tig(g, s, **kwargs):
if g.G.is_directed() and g.G.q != 0 and ('method' not in kwargs or kwargs['method'] != 'exact'):
raise Exception("Only method='exact' is currently supported for magnetic Laplacians.")

if s.ndim == 1:
if s.ndim == 1: # single signal
s = s[..., np.newaxis]
assert s.ndim == 2, "Signal shape was {}".format(s.shape)

if g.G.N == 1:
kwargs['method'] = 'exact'
tig = g.filter(s[..., np.newaxis], **kwargs)
if s.shape[1] == 1: # single signal
tig = tig[..., np.newaxis]
if tig.ndim == 1:
tig = tig[np.newaxis, ..., np.newaxis]
elif tig.ndim == 2: # single filter
tig = tig[..., np.newaxis]

if tig.ndim == 0:
tig = tig[np.newaxis, np.newaxis, np.newaxis]
elif tig.ndim == 1:
if s.shape == (1,1): # multiple filters
tig = tig[np.newaxis, np.newaxis, ...]
elif s.shape[0] == 1: # single node
tig = tig[np.newaxis, ..., np.newaxis]
elif s.shape[1] == 1: # single signal
tig = tig[:, np.newaxis, np.newaxis]
elif tig.ndim == 2:
if s.shape[0] == 1: # single node
tig = tig[np.newaxis, ...]
elif s.shape[1] == 1: # single signal
tig = tig[:, np.newaxis, :]
else: # single filter
tig = tig[..., np.newaxis]

assert tig.shape == s.shape + (tig.shape[2],), "Tig shape was {}".format(tig.shape)
return tig

Expand Down
15 changes: 15 additions & 0 deletions tests/test_sgw.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@ def test_tig(self):
s = np.array([[1,2,3],[-1,-2,-3]])
sgw._tig(g, s)

# multiple filters, single signal
g = gsp.filters.Heat(G, tau=[1,5])
s = np.array([[1],[-1]])
sgw._tig(g, s)

# isolated node
G = sgw.BigGraph([[2]])
g = gsp.filters.Heat(G)
Expand All @@ -33,3 +38,13 @@ def test_tig(self):
# multiple signals
s = np.array([[1,2,3]])
sgw._tig(g, s)

# multiple filters, single signal
g = gsp.filters.Heat(G, tau=[1,5])
s = np.array([1])
sgw._tig(g, s)

# multiple filters, multiple signals
g = gsp.filters.Heat(G, tau=[1,5])
s = np.array([[1,2,3]])
sgw._tig(g, s)

0 comments on commit e63c35d

Please sign in to comment.