Commit: export working
mitkotak committed Jul 27, 2024
1 parent d46b42a commit 5c01a0c
Showing 4 changed files with 133 additions and 69 deletions.
data.py: 21 changes (14 additions & 7 deletions)
@@ -1,8 +1,8 @@
 import torch
 from torch_geometric.data import Data, Batch
-from torch_geometric.transforms import RadiusGraph
+from torch_geometric.nn import radius_graph

-def get_random_graph(nodes, cutoff):
+def get_random_graph(nodes, cutoff) -> Data:

     positions = torch.randn(nodes, 3)
@@ -21,7 +21,7 @@ def get_random_graph(nodes, cutoff):
     graph = Data(
         pos = positions,
         relative_vectors = positions[receivers] - positions[senders],
-        numbers=z,
+        y=z,
         edge_index=edge_index,
         num_nodes=len(positions)
     )
@@ -43,12 +43,19 @@ def get_tetris() -> Batch:
     labels = torch.arange(8)

     graphs = []
-    radius_graph = RadiusGraph(r=1.1)

     for p, l in zip(pos, labels):
-        data = Data(pos=p, y=l)
-        data = radius_graph(data)
+        edge_index = radius_graph(p, r=1.1)
+        data = Data(pos=pos,
+                    y=l,
+                    edge_index = edge_index,
+                    relative_vectors = p[edge_index[0]] - p[edge_index[1]],
+                    num_nodes = 4)

         graphs.append(data)

-    return Batch.from_data_list(graphs)
+    batch = Batch.from_data_list(graphs)
+    batch.pos = batch.pos.view(-1, 3)
+    batch.relative_vectors = batch.relative_vectors.view(-1, 3)
+
+    return batch
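Note on the change above: RadiusGraph is a transform that mutates a Data object, while torch_geometric.nn.radius_graph is a plain function mapping positions to an edge_index tensor, which is simpler to call inside a loop. A minimal sketch of the new call pattern (coordinates and shapes here are illustrative, not from the commit):

import torch
from torch_geometric.nn import radius_graph  # needs torch-cluster installed

# Hypothetical 4-node piece at unit spacing.
p = torch.tensor([[0., 0., 0.],
                  [0., 0., 1.],
                  [1., 0., 0.],
                  [1., 1., 0.]])

# r=1.1 links only nearest neighbours; returns a [2, num_edges] index tensor.
edge_index = radius_graph(p, r=1.1)
relative_vectors = p[edge_index[0]] - p[edge_index[1]]
print(edge_index.shape, relative_vectors.shape)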
model.py: 23 changes (13 additions & 10 deletions)
@@ -1,11 +1,12 @@
 import torch
-torch.jit.script = lambda x: x
 import torch.nn as nn
+import e3nn
+e3nn.set_optimization_defaults(jit_script_fx=False)

 from e3nn import o3
+# from torch_runstats.scatter import scatter_mean
+from utils import scatter_mean

-from torch_runstats.scatter import scatter_mean


 class AtomEmbedding(nn.Module):
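Note on the import change: the deleted line monkey-patched torch.jit.script into a no-op so e3nn submodules would not be TorchScript-compiled; e3nn.set_optimization_defaults(jit_script_fx=False) is the supported switch for the same behaviour. A standalone sketch (assuming an e3nn version that exposes this setting):

import e3nn

# Keep e3nn from TorchScript-compiling its fx-generated submodules,
# which torch.compile / torch.export cannot trace through.
e3nn.set_optimization_defaults(jit_script_fx=False)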
@@ -88,12 +89,12 @@ def __init__(self,
                              output_dims = self.output_dims)

     def forward(self,
-                numbers: torch.Tensor,
+                labels: torch.Tensor,
                 relative_vectors: torch.Tensor,
                 edge_index: torch.Tensor,
                 num_nodes: int) -> torch.Tensor:

-        node_features = self.embed(numbers)
+        node_features = self.embed(labels)
         relative_vectors = relative_vectors
         senders, receivers = edge_index
@@ -122,19 +123,21 @@ def forward(self,


         # Aggregate the node features back.
-        # print("src", node_features_broadcasted.shape)
-        # print("index", receivers.shape)
-        # print("dim", node_features.shape[0])
         node_features = scatter_mean(
             node_features_broadcasted,
-            receivers,
-            dim = node_features.shape[0]
+            receivers.unsqueeze(-1).expand(-1, node_features_broadcasted.shape[-1]),
+            dim=0,
+            dim_size = node_features.shape[0]
         )

         # # Global readout.

         # Filter out 0e
         node_features = self.filter_tp(node_features, self.dummy_input)

-        graph_globals = scatter_mean(node_features, output_dim=[num_nodes])
-
+        graph_globals = scatter_mean(node_features,
+                                     torch.zeros(num_nodes, dtype=torch.int64),
+                                     dim=0,
+                                     dim_size=1)
         return self.readout_mlp(graph_globals)
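Note on the scatter_mean call above: the pure-python helper (added in utils.py below) ultimately drives torch.Tensor.scatter_add_, which wants an index tensor of the same dimensionality as the source, so the 1-D receiver index is expanded across the feature dimension and the aggregation runs along dim=0 with an explicit dim_size; keeping the shapes explicit is presumably also friendlier to torch.export. A small sketch with made-up shapes (6 edge messages, 8 features, 4 nodes), not part of the commit:

import torch
from utils import scatter_mean

messages = torch.randn(6, 8)                  # one row per edge
receivers = torch.tensor([0, 0, 1, 2, 3, 3])  # target node of each edge

# Expand the 1-D index across the feature dimension, then mean-pool
# the messages onto their receiver nodes.
index = receivers.unsqueeze(-1).expand(-1, messages.shape[-1])
node_features = scatter_mean(messages, index, dim=0, dim_size=4)
print(node_features.shape)  # torch.Size([4, 8])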
train.py: 16 changes (5 additions & 11 deletions)
@@ -2,31 +2,25 @@
 sys.path.append("..")

 from model import SimpleNetwork
-from data import get_random_graph
+from data import get_tetris, get_random_graph
 import torch
 from e3nn import o3

+torch._dynamo.config.capture_scalar_outputs = True

-graph = get_random_graph(5, 2.5)
+graph = get_random_graph(nodes=5, cutoff=1.5)
 model = SimpleNetwork(
     relative_vectors_irreps=o3.Irreps.spherical_harmonics(lmax=2),
     node_features_irreps=o3.Irreps("16x0e"),
 )


 model = torch.compile(model, fullgraph=True)

-model(graph.numbers,
+args_in = (graph.y,
     graph.relative_vectors,
     graph.edge_index,
     graph.num_nodes)


-# model = torch.export.export(model,
-#                             (graph.numbers,
-#                              graph.relative_vectors,
-#                              graph.edge_index,
-#                              graph.num_nodes))
+es = torch.export.export(model, args_in)
+print(es)
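For context, torch.export.export traces the module into an ExportedProgram whose graph and signature are what print(es) shows. A hedged follow-up sketch (not in the commit): the exported program can be re-run through its .module() wrapper and should reproduce the eager output.

out_eager = model(*args_in)
out_exported = es.module()(*args_in)
print(torch.allclose(out_eager, out_exported))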


utils.py: 142 changes (101 additions & 41 deletions)
@@ -1,56 +1,116 @@
-## Scatter mean function (courtesy of ChatGPT)
-
-import torch
-
-def scatter_mean(input, index=None, dim=None):
-    if index is not None:
-        # Case 1: Index is specified
-        output_size = index.max().tolist() + 1
-        output = torch.zeros(output_size, input.size(1), device=input.device)
-        n = torch.zeros(output_size, device=input.device)
-
-        for i in range(input.size(0)):
-            idx = index[i]
-            n[idx] += 1
-            output[idx] += (input[i] - output[idx]) / n[idx]
-
-        return output
-
-    elif dim is not None:
-        # Case 2: Index is skipped, output_dim is specified
-        output = torch.zeros(len(dim), input.size(1), device=input.device)
-
-        start_idx = 0
-        for i, dim in enumerate(dim):
-            end_idx = start_idx + dim
-            if dim > 0:
-                segment_sum = input[start_idx:end_idx].sum(dim=0)
-                output[i] = segment_sum / dim
-            start_idx = end_idx
-
-        return output
-
-    else:
-        raise ValueError("Either 'index' or 'dim' must be specified.")
-
-
-# # Example usage for Case 1 (index specified):
-# input1 = torch.randn(3000, 144)
-# index1 = torch.randint(0, 1000, (3000,))
-# output1 = scatter_mean(input1, index=index1)
-# print("Output shape (Case 1):", output1.shape)
-
-# # Example usage for Case 2 (index skipped, output_dim specified):
-# input2 = torch.randn(3000, 144)
-# output_dim = [3000]
-# output2 = scatter_mean(input2, dim=output_dim)
-# print("Output shape (Case 2):", output2.shape)
-
-# # Example usage for Case 3 (both spe):
-# input = torch.randn(3000, 144)
-# index = torch.randint(0, 1000, (3000,))
-# output_dim = [3000]
-
-# output = scatter_mean(input, index, output_dim)
-# print(output.size())  # Should print torch.Size([1000, 144])
+## Adapted from https://github.com/mir-group/pytorch_runstats
+
+"""basic scatter operations from torch_scatter
+Using code from https://github.com/rusty1s/pytorch_scatter, but cut down to avoid a dependency.
+
+PyTorch plans to move these features into the main repo, but until then,
+to make installation simpler, we need this pure python set of wrappers
+that don't require installing PyTorch C++ extensions.
+See https://github.com/pytorch/pytorch/issues/63780.
+"""
+
+from typing import Optional
+
+import torch
+
+
+def _broadcast(src: torch.Tensor, other: torch.Tensor, dim: int):
+    if dim < 0:
+        dim = other.dim() + dim
+    if src.dim() == 1:
+        for _ in range(0, dim):
+            src = src.unsqueeze(0)
+    for _ in range(src.dim(), other.dim()):
+        src = src.unsqueeze(-1)
+    src = src.expand_as(other)
+    return src
+
+
+def scatter(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+    reduce: str = "sum",
+) -> torch.Tensor:
+    assert reduce == "sum"  # for now, TODO
+    index = _broadcast(index, src, dim)
+    if out is None:
+        size = list(src.size())
+        if dim_size is not None:
+            size[dim] = dim_size
+        elif index.numel() == 0:
+            size[dim] = 0
+        else:
+            size[dim] = int(index.max()) + 1
+        out = torch.zeros(size, dtype=src.dtype, device=src.device)
+        return out.scatter_add_(dim, index, src)
+    else:
+        return out.scatter_add_(dim, index, src)
+
+
+def scatter_std(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+    unbiased: bool = True,
+) -> torch.Tensor:
+
+    if out is not None:
+        dim_size = out.size(dim)
+
+    if dim < 0:
+        dim = src.dim() + dim
+
+    count_dim = dim
+    if index.dim() <= dim:
+        count_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter(ones, index, count_dim, dim_size=dim_size)
+
+    index = _broadcast(index, src, dim)
+    tmp = scatter(src, index, dim, dim_size=dim_size)
+    count = _broadcast(count, tmp, dim).clamp(1)
+    mean = tmp.div(count)
+
+    var = src - mean.gather(dim, index)
+    var = var * var
+    out = scatter(var, index, dim, out, dim_size)
+
+    if unbiased:
+        count = count.sub(1).clamp_(1)
+    out = out.div(count + 1e-6).sqrt()
+
+    return out
+
+
+def scatter_mean(
+    src: torch.Tensor,
+    index: torch.Tensor,
+    dim: int = -1,
+    out: Optional[torch.Tensor] = None,
+    dim_size: Optional[int] = None,
+) -> torch.Tensor:
+    out = scatter(src, index, dim, out, dim_size)
+    dim_size = out.size(dim)
+
+    index_dim = dim
+    if index_dim < 0:
+        index_dim = index_dim + src.dim()
+    if index.dim() <= index_dim:
+        index_dim = index.dim() - 1
+
+    ones = torch.ones(index.size(), dtype=src.dtype, device=src.device)
+    count = scatter(ones, index, index_dim, None, dim_size)
+    count[count < 1] = 1
+    count = _broadcast(count, out, dim)
+    if out.is_floating_point():
+        out.true_divide_(count)
+    else:
+        out.div_(count, rounding_mode="floor")
+    return out
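A quick sanity check of the new helper (illustrative, not part of the commit): scatter_mean also accepts a plain 1-D index, because _broadcast expands it across the trailing feature dimension before scatter_add_ runs.

import torch
from utils import scatter_mean

src = torch.tensor([[1., 1.], [3., 3.], [5., 5.]])
index = torch.tensor([0, 0, 1])

# Rows 0 and 1 average into segment 0; row 2 is segment 1 on its own.
out = scatter_mean(src, index, dim=0, dim_size=2)
print(out)  # tensor([[2., 2.], [5., 5.]])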
