Skip to content

Commit

Permalink
Merge pull request epfl-lts2#47 from cgallay/diag_gird2d
Browse files Browse the repository at this point in the history
Grid graph: allow to connect diagonals
  • Loading branch information
mdeff authored Apr 8, 2019
2 parents 7b3884b + 18e999c commit 8ce5bde
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 1 deletion.
19 changes: 18 additions & 1 deletion pygsp/graphs/grid2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ class Grid2d(Graph):
Number of vertices along the first dimension.
N2 : int
Number of vertices along the second dimension. Default is ``N1``.
diagonal : float
Value of the diagnal edges. Default is ``0.0``
See Also
--------
Expand All @@ -36,7 +38,7 @@ class Grid2d(Graph):
"""

def __init__(self, N1=16, N2=None, **kwargs):
def __init__(self, N1=16, N2=None, diagonal=0.0, **kwargs):

if N2 is None:
N2 = N1
Expand All @@ -51,11 +53,26 @@ def __init__(self, N1=16, N2=None, **kwargs):
diag_1 = np.ones(N - 1)
diag_1[(N2 - 1)::N2] = 0
diag_2 = np.ones(N - N2)

W = sparse.diags(diagonals=[diag_1, diag_2],
offsets=[-1, -N2],
shape=(N, N),
format='csr',
dtype='float')

if min(N1, N2) > 1 and diagonal != 0.0:
# Connecting node with they diagonal neighbours
diag_3 = np.full(N - N2 - 1, diagonal)
diag_4 = np.full(N - N2 + 1, diagonal)
diag_3[N2 - 1::N2] = 0
diag_4[0::N2] = 0
D = sparse.diags(diagonals=[diag_3, diag_4],
offsets=[-N2 - 1, -N2 + 1],
shape=(N, N),
format='csr',
dtype='float')
W += D

W = utils.symmetrize(W, method='tril')

x = np.kron(np.ones((N1, 1)), (np.arange(N2)/float(N2)).reshape(N2, 1))
Expand Down
10 changes: 10 additions & 0 deletions pygsp/tests/test_graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -589,6 +589,16 @@ def test_imgpatches(self):
def test_grid2dimgpatches(self):
graphs.Grid2dImgPatches(img=self._img, patch_shape=(3, 3))

def test_grid2d_diagonals(self):
value = 0.5
G = graphs.Grid2d(6, 7, diagonal=value)
self.assertEqual(G.W[2, 8], value)
self.assertEqual(G.W[9, 1], value)
self.assertEqual(G.W[9, 3], value)
self.assertEqual(G.W[2, 14], 0.0)
self.assertEqual(G.W[17, 1], 0.0)
self.assertEqual(G.W[9, 16], 1.0)
self.assertEqual(G.W[20, 27], 1.0)

suite_graphs = unittest.TestLoader().loadTestsFromTestCase(TestCase)

Expand Down

0 comments on commit 8ce5bde

Please sign in to comment.