diff --git a/setup.py b/setup.py index 2b3aadc..d4c812e 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="sgw_tools", - version="2.5.5", + version="3.0", author="Mark Hale", license="MIT", description="Spectral graph wavelet tools", diff --git a/sgw_torch/__init__.py b/sgw_torch/__init__.py new file mode 100644 index 0000000..a280bf5 --- /dev/null +++ b/sgw_torch/__init__.py @@ -0,0 +1,40 @@ +from typing import Optional + +import numpy as np +import torch +from torch import Tensor +import torch_geometric.utils as torch_utils +from torch_geometric.typing import OptTensor +from sgw_tools import util + +from . import layers + +def edge_tensors(G): + edge_list = G.get_edge_list() + edge_index = np.array([ + np.concatenate((edge_list[0], edge_list[1])), + np.concatenate((edge_list[1], edge_list[0])) + ]) + edge_weight = np.concatenate((edge_list[2], edge_list[2])) + return torch.from_numpy(edge_index).long(), torch.from_numpy(edge_weight) + + +def get_laplacian(edge_index: Tensor, + edge_weight: OptTensor = None, + lap_type: str = "normalized", + dtype: Optional[torch.dtype] = None, + num_nodes: Optional[int] = None): + if lap_type == "adjacency": + if edge_weight is None: + edge_weight = torch.ones(edge_index.size[1], dtype=dtype, device=edge_index.device) + op_norm = util.operator_norm(torch_utils.to_scipy_sparse_matrix(edge_index, edge_weight)) + num_nodes = torch_utils.num_nodes.maybe_num_nodes(edge_index, num_nodes) + return torch_utils.add_self_loops(edge_index, -edge_weight/op_norm, fill_value=1.0, num_nodes=num_nodes) + else: + normalization_mapping = { + "combinatorial": None, + "normalized": "sym" + } + normalization = normalization_mapping[lap_type] + return torch_utils.get_laplacian(edge_index, edge_weight=edge_weight, normalization=normalization, dtype=dtype, num_nodes=num_nodes) + \ No newline at end of file diff --git a/sgw_torch/layers.py b/sgw_torch/layers.py new file mode 100644 index 0000000..3aebee2 --- /dev/null +++ b/sgw_torch/layers.py @@ -0,0 +1,216 @@ +from typing import Optional + +import torch +from torch import Tensor +from torch.nn import Parameter +import torch.nn.functional as F + +from torch_geometric.nn.conv import MessagePassing +from torch_geometric.nn.dense.linear import Linear +from torch_geometric.nn.inits import zeros +from torch_geometric.typing import OptTensor +import sgw_torch +import numpy as np + + +class SpectralLayer(MessagePassing): + def __init__(self, + in_channels: int, + out_channels: int, + lap_type: str, + ): + super().__init__(aggr="add") + self.in_channels = in_channels + self.out_channels = out_channels + self.lap_type = lap_type + + def get_laplacian(self, + edge_index: Tensor, + edge_weight: OptTensor = None, + lap_type: str = "normalized", + dtype: Optional[torch.dtype] = None, + num_nodes: Optional[int] = None + ): + return sgw_torch.get_laplacian(edge_index, edge_weight=edge_weight, lap_type=lap_type, dtype=dtype, num_nodes=num_nodes) + + def get_lambda_max(self, edge_weight: Tensor, num_nodes: int) -> Tensor: + if self.lap_type == "normalized" or self.lap_type == "adjacency": + return torch.tensor(2.0, dtype=edge_weight.dtype) + elif self.lap_type == "combinatorial": + D = edge_weight[edge_weight > 0] + A = -edge_weight[edge_weight < 0] + return min(num_nodes*A.max(), 2.0 * D.max()) + else: + raise ValueError(f"Unsupported lap_type {self.lap_type}") + + def message(self, x_j: Tensor, norm: Tensor) -> Tensor: + return norm.view(-1, 1) * x_j + + +class ChebLayer(SpectralLayer): + def __init__( + self, + in_channels: int, + out_channels: int, + K: int, + lap_type: str ='normalized', + bias: bool = True, + ): + super().__init__(in_channels, out_channels, lap_type) + self.lins = torch.nn.ModuleList([ + Linear(in_channels, out_channels, bias=False, + weight_initializer='glorot') for _ in range(K) + ]) + + if bias: + self.bias = Parameter(Tensor(out_channels)) + else: + self.register_parameter('bias', None) + + self.reset_parameters() + + def reset_parameters(self): + super().reset_parameters() + for lin in self.lins: + lin.reset_parameters() + zeros(self.bias) + + def __norm__( + self, + edge_index: Tensor, + num_nodes: int, + edge_weight: OptTensor, + lambda_max: OptTensor = None, + dtype: Optional[int] = None, + batch: OptTensor = None, + ): + edge_index, edge_weight = self.get_laplacian(edge_index, edge_weight, + self.lap_type, dtype, + num_nodes) + assert edge_weight is not None + + if lambda_max is None: + lambda_max = self.get_lambda_max(edge_weight, num_nodes) + elif not isinstance(lambda_max, Tensor): + lambda_max = torch.tensor(lambda_max, dtype=dtype, + device=edge_index.device) + assert lambda_max is not None + + if batch is not None and lambda_max.numel() > 1: + lambda_max = lambda_max[batch[edge_index[0]]] + + edge_weight = (2.0 * edge_weight) / lambda_max + edge_weight.masked_fill_(edge_weight == float('inf'), 0) + + loop_mask = edge_index[0] == edge_index[1] + edge_weight[loop_mask] -= 1 + + return edge_index, edge_weight + + def _weight_op(self, x: Tensor, weight: Tensor): + return F.linear(x, weight) + + def _evaluate_chebyshev( + self, + coeffs: list[Tensor], + x: Tensor, + edge_index: Tensor, + edge_weight: OptTensor = None, + batch: OptTensor = None, + lambda_max: OptTensor = None, + ) -> Tensor: + + edge_index, norm = self.__norm__( + edge_index, + x.size(self.node_dim), + edge_weight, + lambda_max, + dtype=x.dtype, + batch=batch, + ) + + Tx_0 = x + out = self._weight_op(Tx_0, coeffs[0]) + + # propagate_type: (x: Tensor, norm: Tensor) + if len(coeffs) > 1: + Tx_1 = self.propagate(edge_index, x=x, norm=norm) + out = out + self._weight_op(Tx_1, coeffs[1]) + + for coeff in coeffs[2:]: + Tx_2 = self.propagate(edge_index, x=Tx_1, norm=norm) + Tx_2 = 2.0 * Tx_2 - Tx_0 + out = out + self._weight_op(Tx_2, coeff) + Tx_0, Tx_1 = Tx_1, Tx_2 + + if self.bias is not None: + out = out + self.bias + + return out + + def forward( + self, + x: Tensor, + edge_index: Tensor, + edge_weight: OptTensor = None, + batch: OptTensor = None, + lambda_max: OptTensor = None, + ) -> Tensor: + ws = [lin.weight for lin in self.lins] + return self._evaluate_chebyshev(ws, x, edge_index, edge_weight, batch, lambda_max) + + def __repr__(self) -> str: + return (f'{self.__class__.__name__}({self.in_channels}, ' + f'{self.out_channels}, K={len(self.lins)}, ' + f'lap_type={self.lap_type})') + + +class ChebIILayer(ChebLayer): + """ + https://github.com/ivam-he/ChebNetII + """ + def __init__( + self, + in_channels: int, + out_channels: int, + K: int, + lap_type: str ='normalized', + bias: bool = True, + ): + super().__init__(in_channels, out_channels, K, lap_type, bias) + + def convert_coefficients(self, ys): + k1 = len(self.lins) + + def transform(Ts): + itr = zip(ys, Ts) + first = next(itr) + w = first[0] * first[1] # initialise with correct shape + for y, T in itr: + w += y*T + return 2*w/k1 + + ws = [] + x = torch.from_numpy(np.polynomial.chebyshev.chebpts1(k1)) + a = torch.ones(k1) + b = x + ws.append(transform(a/2)) + ws.append(transform(b)) + for _ in range(2, k1): + T = 2*x*b - a + ws.append(transform(T)) + a = b + b = T + return ws + + def forward( + self, + x: Tensor, + edge_index: Tensor, + edge_weight: OptTensor = None, + batch: OptTensor = None, + lambda_max: OptTensor = None, + ) -> Tensor: + ys = [lin.weight for lin in self.lins] + ws = self.convert_coefficients(ys) + return self._evaluate_cheb(ws, x, edge_index, edge_weight, batch, lambda_max) diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 0000000..844f44d --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,36 @@ +import unittest + +import numpy as np +import sgw_tools as sgw +import sgw_torch +import pygsp as gsp +import torch +import torch_geometric.utils as torch_utils + + +class TestCase(unittest.TestCase): + def test_get_laplacian(self): + G = sgw.BigGraph.create_from(gsp.graphs.Sensor(50, seed=32), lap_type="adjacency") + edge_index, edge_weight = sgw_torch.edge_tensors(G) + L_index, L_weight = sgw_torch.get_laplacian(edge_index, edge_weight, lap_type="adjacency", num_nodes=G.N) + torch_L = torch_utils.to_scipy_sparse_matrix(L_index, L_weight) + np.testing.assert_allclose(G.L.toarray(), torch_L.toarray()) + + def test_ChebLayer(self): + G = gsp.graphs.Sensor(34, seed=506, lap_type="normalized") + K = 5 + domain = [0, 2] + coeffs = sgw.approximations.compute_cheby_coeff(gsp.filters.Heat(G, tau=5), m=K, domain=domain) + coeffs[0] /= 2 + + s = sgw.createSignal(G, nodes=[0]) + g = sgw.ChebyshevFilter(G, coeffs, domain, coeff_normalization="numpy") + expected = g.filter(s) + + layer = sgw_torch.layers.ChebLayer(1, 1, K+1, lap_type="normalized", bias=False) + for c, p in zip(coeffs, layer.parameters()): + p.data = torch.tensor([[c]]) + edge_index, edge_weight = sgw_torch.edge_tensors(G) + y = layer(torch.from_numpy(s), edge_index, edge_weight) + + np.testing.assert_allclose(expected, y.detach().numpy().squeeze())