Add sgw_torch package.
Mark Hale committed Jul 22, 2024
1 parent b324b09 commit 8d692f4
Showing 4 changed files with 293 additions and 1 deletion.
2 changes: 1 addition & 1 deletion setup.py
@@ -5,7 +5,7 @@

setup(
name="sgw_tools",
version="2.5.5",
version="3.0",
author="Mark Hale",
license="MIT",
description="Spectral graph wavelet tools",
40 changes: 40 additions & 0 deletions sgw_torch/__init__.py
@@ -0,0 +1,40 @@
from typing import Optional

import numpy as np
import torch
from torch import Tensor
import torch_geometric.utils as torch_utils
from torch_geometric.typing import OptTensor
from sgw_tools import util

from . import layers

def edge_tensors(G):
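# Duplicate each undirected edge in both directions to produce PyG-style
# COO edge_index / edge_weight tensors.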
edge_list = G.get_edge_list()
edge_index = np.array([
np.concatenate((edge_list[0], edge_list[1])),
np.concatenate((edge_list[1], edge_list[0]))
])
edge_weight = np.concatenate((edge_list[2], edge_list[2]))
return torch.from_numpy(edge_index).long(), torch.from_numpy(edge_weight)


def get_laplacian(edge_index: Tensor,
edge_weight: OptTensor = None,
lap_type: str = "normalized",
dtype: Optional[torch.dtype] = None,
num_nodes: Optional[int] = None):
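# "adjacency" builds I - A/||A|| (assuming util.operator_norm gives the spectral
# norm of A), so the spectrum lies in [0, 2]; other lap_types defer to
# torch_geometric's get_laplacian.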
if lap_type == "adjacency":
if edge_weight is None:
edge_weight = torch.ones(edge_index.size(1), dtype=dtype, device=edge_index.device)
op_norm = util.operator_norm(torch_utils.to_scipy_sparse_matrix(edge_index, edge_weight))
num_nodes = torch_utils.num_nodes.maybe_num_nodes(edge_index, num_nodes)
return torch_utils.add_self_loops(edge_index, -edge_weight/op_norm, fill_value=1.0, num_nodes=num_nodes)
else:
normalization_mapping = {
"combinatorial": None,
"normalized": "sym"
}
normalization = normalization_mapping[lap_type]
return torch_utils.get_laplacian(edge_index, edge_weight=edge_weight, normalization=normalization, dtype=dtype, num_nodes=num_nodes)
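
For context, a minimal usage sketch of the two helpers above (not part of the commit; it mirrors test_get_laplacian further down):

import pygsp as gsp
import sgw_tools as sgw
import sgw_torch

# wrap a pygsp graph with sgw_tools so get_edge_list() and N are available
G = sgw.BigGraph.create_from(gsp.graphs.Sensor(50, seed=32), lap_type="adjacency")
edge_index, edge_weight = sgw_torch.edge_tensors(G)
L_index, L_weight = sgw_torch.get_laplacian(edge_index, edge_weight, lap_type="adjacency", num_nodes=G.N)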

216 changes: 216 additions & 0 deletions sgw_torch/layers.py
@@ -0,0 +1,216 @@
from typing import Optional

import torch
from torch import Tensor
from torch.nn import Parameter
import torch.nn.functional as F

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import zeros
from torch_geometric.typing import OptTensor
import sgw_torch
import numpy as np


class SpectralLayer(MessagePassing):
def __init__(self,
in_channels: int,
out_channels: int,
lap_type: str,
):
super().__init__(aggr="add")
self.in_channels = in_channels
self.out_channels = out_channels
self.lap_type = lap_type

def get_laplacian(self,
edge_index: Tensor,
edge_weight: OptTensor = None,
lap_type: str = "normalized",
dtype: Optional[torch.dtype] = None,
num_nodes: Optional[int] = None
):
return sgw_torch.get_laplacian(edge_index, edge_weight=edge_weight, lap_type=lap_type, dtype=dtype, num_nodes=num_nodes)

def get_lambda_max(self, edge_weight: Tensor, num_nodes: int) -> Tensor:
if self.lap_type == "normalized" or self.lap_type == "adjacency":
return torch.tensor(2.0, dtype=edge_weight.dtype)
elif self.lap_type == "combinatorial":
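# edge_weight holds Laplacian entries: positive values are degrees (self-loops),
# negative values are -adjacency; bound lambda_max by the smaller of
# n * max adjacency weight and 2 * max degree (Gershgorin-style bounds).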
D = edge_weight[edge_weight > 0]
A = -edge_weight[edge_weight < 0]
return min(num_nodes*A.max(), 2.0 * D.max())
else:
raise ValueError(f"Unsupported lap_type {self.lap_type}")

def message(self, x_j: Tensor, norm: Tensor) -> Tensor:
return norm.view(-1, 1) * x_j


class ChebLayer(SpectralLayer):
def __init__(
self,
in_channels: int,
out_channels: int,
K: int,
lap_type: str = 'normalized',
bias: bool = True,
):
super().__init__(in_channels, out_channels, lap_type)
self.lins = torch.nn.ModuleList([
Linear(in_channels, out_channels, bias=False,
weight_initializer='glorot') for _ in range(K)
])

if bias:
self.bias = Parameter(Tensor(out_channels))
else:
self.register_parameter('bias', None)

self.reset_parameters()

def reset_parameters(self):
super().reset_parameters()
for lin in self.lins:
lin.reset_parameters()
zeros(self.bias)

def __norm__(
self,
edge_index: Tensor,
num_nodes: int,
edge_weight: OptTensor,
lambda_max: OptTensor = None,
dtype: Optional[torch.dtype] = None,
batch: OptTensor = None,
):
edge_index, edge_weight = self.get_laplacian(edge_index, edge_weight,
self.lap_type, dtype,
num_nodes)
assert edge_weight is not None

if lambda_max is None:
lambda_max = self.get_lambda_max(edge_weight, num_nodes)
elif not isinstance(lambda_max, Tensor):
lambda_max = torch.tensor(lambda_max, dtype=dtype,
device=edge_index.device)
assert lambda_max is not None

if batch is not None and lambda_max.numel() > 1:
lambda_max = lambda_max[batch[edge_index[0]]]

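# rescale the spectrum to [-1, 1]: L_tilde = 2 L / lambda_max - I
# (the -I shift is applied to the self-loop entries just below)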
edge_weight = (2.0 * edge_weight) / lambda_max
edge_weight.masked_fill_(edge_weight == float('inf'), 0)

loop_mask = edge_index[0] == edge_index[1]
edge_weight[loop_mask] -= 1

return edge_index, edge_weight

def _weight_op(self, x: Tensor, weight: Tensor):
return F.linear(x, weight)

def _evaluate_chebyshev(
self,
coeffs: list[Tensor],
x: Tensor,
edge_index: Tensor,
edge_weight: OptTensor = None,
batch: OptTensor = None,
lambda_max: OptTensor = None,
) -> Tensor:

edge_index, norm = self.__norm__(
edge_index,
x.size(self.node_dim),
edge_weight,
lambda_max,
dtype=x.dtype,
batch=batch,
)

Tx_0 = x
out = self._weight_op(Tx_0, coeffs[0])

# propagate_type: (x: Tensor, norm: Tensor)
if len(coeffs) > 1:
Tx_1 = self.propagate(edge_index, x=x, norm=norm)
out = out + self._weight_op(Tx_1, coeffs[1])

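# three-term Chebyshev recurrence: T_k(x) = 2 x T_{k-1}(x) - T_{k-2}(x)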
for coeff in coeffs[2:]:
Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm)
Tx_2 = 2.0 * Tx_2 - Tx_0
out = out + self._weight_op(Tx_2, coeff)
Tx_0, Tx_1 = Tx_1, Tx_2

if self.bias is not None:
out = out + self.bias

return out

def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_weight: OptTensor = None,
batch: OptTensor = None,
lambda_max: OptTensor = None,
) -> Tensor:
ws = [lin.weight for lin in self.lins]
return self._evaluate_chebyshev(ws, x, edge_index, edge_weight, batch, lambda_max)

def __repr__(self) -> str:
return (f'{self.__class__.__name__}({self.in_channels}, '
f'{self.out_channels}, K={len(self.lins)}, '
f'lap_type={self.lap_type})')


class ChebIILayer(ChebLayer):
"""
https://github.com/ivam-he/ChebNetII
"""
def __init__(
self,
in_channels: int,
out_channels: int,
K: int,
lap_type: str = 'normalized',
bias: bool = True,
):
super().__init__(in_channels, out_channels, K, lap_type, bias)

def convert_coefficients(self, ys):
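# ChebNetII reparameterisation: ys are the learned filter values at the k1
# Chebyshev points x_j; convert to Chebyshev coefficients via
# w_k = (2/k1) * sum_j y_j * T_k(x_j), with the k = 0 term halved.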
k1 = len(self.lins)

def transform(Ts):
itr = zip(ys, Ts)
first = next(itr)
w = first[0] * first[1] # initialise with correct shape
for y, T in itr:
w += y*T
return 2*w/k1

ws = []
x = torch.from_numpy(np.polynomial.chebyshev.chebpts1(k1)).to(ys[0].dtype)
a = torch.ones(k1, dtype=x.dtype)
b = x
ws.append(transform(a/2))
ws.append(transform(b))
for _ in range(2, k1):
T = 2*x*b - a
ws.append(transform(T))
a = b
b = T
return ws

def forward(
self,
x: Tensor,
edge_index: Tensor,
edge_weight: OptTensor = None,
batch: OptTensor = None,
lambda_max: OptTensor = None,
) -> Tensor:
ys = [lin.weight for lin in self.lins]
ws = self.convert_coefficients(ys)
return self._evaluate_chebyshev(ws, x, edge_index, edge_weight, batch, lambda_max)
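
A minimal end-to-end sketch of ChebLayer on a toy graph (illustrative only, not from the commit; the graph and channel sizes are arbitrary):

import torch
import sgw_torch

# 3-node path graph; each undirected edge is listed in both directions
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]])
edge_weight = torch.ones(4)
x = torch.randn(3, 4)                          # 3 nodes, 4 input channels
layer = sgw_torch.layers.ChebLayer(4, 8, K=3)  # Chebyshev orders 0..2
out = layer(x, edge_index, edge_weight)        # shape: [3, 8]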
36 changes: 36 additions & 0 deletions tests/test_torch.py
@@ -0,0 +1,36 @@
import unittest

import numpy as np
import sgw_tools as sgw
import sgw_torch
import pygsp as gsp
import torch
import torch_geometric.utils as torch_utils


class TestCase(unittest.TestCase):
def test_get_laplacian(self):
G = sgw.BigGraph.create_from(gsp.graphs.Sensor(50, seed=32), lap_type="adjacency")
edge_index, edge_weight = sgw_torch.edge_tensors(G)
L_index, L_weight = sgw_torch.get_laplacian(edge_index, edge_weight, lap_type="adjacency", num_nodes=G.N)
torch_L = torch_utils.to_scipy_sparse_matrix(L_index, L_weight)
np.testing.assert_allclose(G.L.toarray(), torch_L.toarray())

def test_ChebLayer(self):
G = gsp.graphs.Sensor(34, seed=506, lap_type="normalized")
K = 5
domain = [0, 2]
coeffs = sgw.approximations.compute_cheby_coeff(gsp.filters.Heat(G, tau=5), m=K, domain=domain)
coeffs[0] /= 2

s = sgw.createSignal(G, nodes=[0])
g = sgw.ChebyshevFilter(G, coeffs, domain, coeff_normalization="numpy")
expected = g.filter(s)

layer = sgw_torch.layers.ChebLayer(1, 1, K+1, lap_type="normalized", bias=False)
for c, p in zip(coeffs, layer.parameters()):
p.data = torch.tensor([[c]])
edge_index, edge_weight = sgw_torch.edge_tensors(G)
y = layer(torch.from_numpy(s), edge_index, edge_weight)

np.testing.assert_allclose(expected, y.detach().numpy().squeeze())
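
For intuition, the identity that test_get_laplacian exercises can be checked densely; a sketch, assuming util.operator_norm returns the spectral norm:

import numpy as np

A = np.array([[0., 2.],
              [2., 0.]])        # weighted adjacency
op_norm = np.linalg.norm(A, 2)  # spectral norm, here 2.0
L = np.eye(2) - A / op_norm     # what lap_type="adjacency" constructs
print(np.linalg.eigvalsh(L))    # [0., 2.] -- spectrum lies in [0, 2]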
