From 5c01a0cc1e70c450d65c7d72042eefa514daf5df Mon Sep 17 00:00:00 2001
From: Mit Kotak
Date: Sat, 27 Jul 2024 17:33:45 -0400
Subject: [PATCH] export working

---
 data.py  |  21 +++++---
 model.py |  23 +++++----
 train.py |  16 ++-----
 utils.py | 142 +++++++++++++++++++++++++++++++++++++++----------------
 4 files changed, 133 insertions(+), 69 deletions(-)

diff --git a/data.py b/data.py
index 9bc6d49..3969309 100644
--- a/data.py
+++ b/data.py
@@ -1,8 +1,8 @@
 import torch
 from torch_geometric.data import Data, Batch
-from torch_geometric.transforms import RadiusGraph
+from torch_geometric.nn import radius_graph

-def get_random_graph(nodes, cutoff):
+def get_random_graph(nodes, cutoff) -> Data:

     positions = torch.randn(nodes, 3)

@@ -21,7 +21,7 @@ def get_random_graph(nodes, cutoff):
     graph = Data(
         pos = positions,
         relative_vectors = positions[receivers] - positions[senders],
-        numbers=z,
+        y=z,
         edge_index=edge_index,
         num_nodes=len(positions)
     )
@@ -43,12 +43,19 @@ def get_tetris() -> Batch:
     labels = torch.arange(8)

     graphs = []
-    radius_graph = RadiusGraph(r=1.1)
     for p, l in zip(pos, labels):
-        data = Data(pos=p, y=l)
-        data = radius_graph(data)
+        edge_index = radius_graph(p, r=1.1)
+        data = Data(pos=p,
+                    y=l,
+                    edge_index = edge_index,
+                    relative_vectors = p[edge_index[0]] - p[edge_index[1]],
+                    num_nodes = 4)
         graphs.append(data)

-    return Batch.from_data_list(graphs)
\ No newline at end of file
+    batch = Batch.from_data_list(graphs)
+    batch.pos = batch.pos.view(-1, 3)
+    batch.relative_vectors = batch.relative_vectors.view(-1, 3)
+
+    return batch
\ No newline at end of file

diff --git a/model.py b/model.py
index dc872f9..e9a6b61 100644
--- a/model.py
+++ b/model.py
@@ -1,11 +1,12 @@
 import torch
+torch.jit.script = lambda x: x  # make torch.jit.script a no-op so torch.export can trace e3nn code
 import torch.nn as nn

 import e3nn
 e3nn.set_optimization_defaults(jit_script_fx=False)
 from e3nn import o3
-# from torch_runstats.scatter import scatter_mean
-from utils import scatter_mean
+
+from torch_runstats.scatter import scatter_mean


 class AtomEmbedding(nn.Module):
@@ -88,12 +89,12 @@ def __init__(self,
                          output_dims = self.output_dims)

     def forward(self,
-                numbers: torch.Tensor,
+                labels: torch.Tensor,
                 relative_vectors: torch.Tensor,
                 edge_index: torch.Tensor,
                 num_nodes: int) -> torch.Tensor:

-        node_features = self.embed(numbers)
+        node_features = self.embed(labels)

         relative_vectors = relative_vectors
         senders, receivers = edge_index
@@ -122,13 +123,11 @@ def forward(self,

         # Aggregate the node features back.
-        # print("src", node_features_broadcasted.shape)
-        # print("index", receivers.shape)
-        # print("dim", node_features.shape[0])
         node_features = scatter_mean(
             node_features_broadcasted,
-            receivers,
-            dim = node_features.shape[0]
+            receivers.unsqueeze(-1).expand(-1, node_features_broadcasted.shape[-1]),
+            dim=0,
+            dim_size = node_features.shape[0]
         )

         # # Global readout.
@@ -136,5 +135,9 @@ def forward(self,

         # Filter out 0e
         node_features = self.filter_tp(node_features, self.dummy_input)
-        graph_globals = scatter_mean(node_features, output_dim=[num_nodes])
+
+        graph_globals = scatter_mean(node_features,
+                                     torch.zeros(num_nodes, dtype=torch.int64),
+                                     dim=0,
+                                     dim_size=1)
         return self.readout_mlp(graph_globals)
\ No newline at end of file

diff --git a/train.py b/train.py
index 4ca434f..dd94ebd 100644
--- a/train.py
+++ b/train.py
@@ -2,31 +2,25 @@
 sys.path.append("..")

 from model import SimpleNetwork
-from data import get_random_graph
+from data import get_tetris, get_random_graph

 import torch
 from e3nn import o3

 torch._dynamo.config.capture_scalar_outputs = True

-graph = get_random_graph(5, 2.5)
+graph = get_random_graph(nodes=5, cutoff=1.5)

 model = SimpleNetwork(
     relative_vectors_irreps=o3.Irreps.spherical_harmonics(lmax=2),
     node_features_irreps=o3.Irreps("16x0e"),
 )

-model = torch.compile(model, fullgraph=True)
-
-model(graph.numbers,
+args_in = (graph.y,
       graph.relative_vectors,
       graph.edge_index,
       graph.num_nodes)
-
-# model = torch.export.export(model,
-#     (graph.numbers,
-#     graph.relative_vectors,
-#     graph.edge_index,
-#     graph.num_nodes))
+es = torch.export.export(model, args_in)
+print(es)

diff --git a/utils.py b/utils.py
index ff71991..7c09e04 100644
--- a/utils.py
+++ b/utils.py
@@ -1,56 +1,116 @@
-## Scatter mean function (courtesy of ChatGPT)
+## Adapted from https://github.com/mir-group/pytorch_runstats

-import torch
+"""basic scatter operations from torch_scatter

+Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.

-def scatter_mean(input, index=None, dim=None):
-    if index is not None:
-        # Case 1: Index is specified
-        output_size = index.max().tolist() + 1
-        output = torch.zeros(output_size, input.size(1), device=input.device)
-        n = torch.zeros(output_size, device=input.device)
+PyTorch plans to move these features into the main repo, but until then,
+to make installation simpler, we need this pure python set of wrappers
+that don't require installing PyTorch C++ extensions.
+See https://github.com/pytorch/pytorch/issues/63780.
+""" - for i in range(input.size(0)): - idx = index[i] - n[idx] += 1 - output[idx] += (input[i] - output[idx]) / n[idx] +from typing import Optional - return output +import torch - elif dim is not None: - # Case 2: Index is skipped, output_dim is specified - output = torch.zeros(len(dim), input.size(1), device=input.device) - start_idx = 0 - for i, dim in enumerate(dim): - end_idx = start_idx + dim - if dim > 0: - segment_sum = input[start_idx:end_idx].sum(dim=0) - output[i] = segment_sum / dim - start_idx = end_idx +def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int): + if dim < 0: + dim = other.dim() + dim + if src.dim() == 1: + for _ in range(0, dim): + src = src.unsqueeze(0) + for _ in range(src.dim(), other.dim()): + src = src.unsqueeze(-1) + src = src.expand_as(other) + return src - return output +def scatter( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + reduce: str = "sum", +) -> torch.Tensor: + assert reduce == "sum" # for now, TODO + index = _broadcast(index, src, dim) + if out is None: + size = list(src.size()) + if dim_size is not None: + size[dim] = dim_size + elif index.numel() == 0: + size[dim] = 0 + else: + size[dim] = int(index.max()) + 1 + out = torch.zeros(size, dtype=src.dtype, device=src.device) + return out.scatter_add_(dim, index, src) else: - raise ValueError("Either 'index' or 'dim' must be specified.") + return out.scatter_add_(dim, index, src) + + +def scatter_std( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, + unbiased: bool = True, +) -> torch.Tensor: + + if out is not None: + dim_size = out.size(dim) + + if dim < 0: + dim = src.dim() + dim + count_dim = dim + if index.dim() <= dim: + count_dim = index.dim() - 1 -# # Example usage for Case 1 (index specified): -# input1 = torch.randn(3000, 144) -# index1 = torch.randint(0, 1000, (3000,)) -# output1 = scatter_mean(input1, index=index1) -# print("Output shape (Case 1):", output1.shape) + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter(ones, index, count_dim, dim_size=dim_size) -# # Example usage for Case 2 (index skipped, output_dim specified): -# input2 = torch.randn(3000, 144) -# output_dim = [3000] -# output2 = scatter_mean(input2, dim=output_dim) -# print("Output shape (Case 2):", output2.shape) + index = _broadcast(index, src, dim) + tmp = scatter(src, index, dim, dim_size=dim_size) + count = _broadcast(count, tmp, dim).clamp(1) + mean = tmp.div(count) -# # Example usage for Case 3 (both spe): -# input = torch.randn(3000, 144) -# index = torch.randint(0, 1000, (3000,)) -# output_dim = [3000] + var = src - mean.gather(dim, index) + var = var * var + out = scatter(var, index, dim, out, dim_size) -# output = scatter_mean(input, index, output_dim) -# print(output.size()) # Should print torch.Size([1000, 144]) + if unbiased: + count = count.sub(1).clamp_(1) + out = out.div(count + 1e-6).sqrt() + + return out + + +def scatter_mean( + src: torch.Tensor, + index: torch.Tensor, + dim: int = -1, + out: Optional[torch.Tensor] = None, + dim_size: Optional[int] = None, +) -> torch.Tensor: + out = scatter(src, index, dim, out, dim_size) + dim_size = out.size(dim) + + index_dim = dim + if index_dim < 0: + index_dim = index_dim + src.dim() + if index.dim() <= index_dim: + index_dim = index.dim() - 1 + + ones = torch.ones(index.size(), dtype=src.dtype, device=src.device) + count = scatter(ones, index, 
index_dim, None, dim_size) + count[count < 1] = 1 + count = _broadcast(count, out, dim) + if out.is_floating_point(): + out.true_divide_(count) + else: + out.div_(count, rounding_mode="floor") + return out \ No newline at end of file
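
A minimal sanity check of the vendored scatter_mean above (a sketch, not part of
the patch; assumes the new utils.py is on the import path). The 2-D index mirrors
how model.py expands `receivers` to match the message matrix before aggregating
edge messages onto nodes:

    import torch
    from utils import scatter_mean

    # Three edge messages aggregated onto two nodes:
    # rows 0 and 1 go to node 0, row 2 goes to node 1.
    src = torch.tensor([[1.0, 2.0],
                        [3.0, 4.0],
                        [5.0, 6.0]])
    receivers = torch.tensor([0, 0, 1])

    # Broadcast the index to src's shape, as model.py does.
    index = receivers.unsqueeze(-1).expand(-1, src.shape[-1])

    out = scatter_mean(src, index, dim=0, dim_size=2)
    print(out)  # tensor([[2., 3.], [5., 6.]])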
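A possible follow-up check for train.py (a sketch, not part of the patch):
torch.export.export returns an ExportedProgram, and its .module() is a
GraphModule callable with the same arguments as SimpleNetwork.forward:

    # After `es = torch.export.export(model, args_in)` in train.py:
    m = es.module()     # traced graph, callable like the original model
    out = m(*args_in)   # (labels, relative_vectors, edge_index, num_nodes)
    print(out.shape)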