still facing the scatter_mean issue
mitkotak committed Jul 25, 2024
1 parent 15fba59 commit d46b42a
Showing 4 changed files with 44 additions and 39 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -1 +1,2 @@
*__pycache__
+*.venv
11 changes: 7 additions & 4 deletions model.py
@@ -64,25 +64,25 @@ def __init__(self,

        # Currently hardcoding 1 layer

-        print("node_features_irreps", node_features_irreps)
+        # print("node_features_irreps", node_features_irreps)

        self.tp = o3.FullTensorProduct(relative_vectors_irreps.regroup(),
                                       node_features_irreps.regroup(),
                                       filter_ir_out=[o3.Irrep(f"{l}e") for l in range(self.lmax+1)] + [o3.Irrep(f"{l}o") for l in range(self.lmax+1)])
        self.linear = o3.Linear(irreps_in=self.tp.irreps_out.regroup(), irreps_out=self.tp.irreps_out.regroup())
-        print("TP+Linear", self.linear.irreps_out)
+        # print("TP+Linear", self.linear.irreps_out)
        self.mlp = MLP(input_dims = 1,  # Since we are inputting the norms, the input will always be (..., 1)
                       output_dims = self.tp.irreps_out.num_irreps)


        self.elementwise_tp = o3.ElementwiseTensorProduct(o3.Irreps(f"{self.tp.irreps_out.num_irreps}x0e"), self.linear.irreps_out.regroup())
-        print("node feature broadcasted", self.elementwise_tp.irreps_out)
+        # print("node feature broadcasted", self.elementwise_tp.irreps_out)

        # Poor man's filter function (can already feel the judgement). Replicates irreps_array.filter("0e")
        self.filter_tp = o3.FullTensorProduct(self.tp.irreps_out.regroup(), o3.Irreps("0e"), filter_ir_out=[o3.Irrep("0e")])
        self.register_buffer("dummy_input", torch.ones(1))

-        print("aggregated node features", self.filter_tp.irreps_out)
+        # print("aggregated node features", self.filter_tp.irreps_out)

        self.readout_mlp = MLP(input_dims = self.filter_tp.irreps_out.num_irreps,
                               output_dims = self.output_dims)
@@ -122,6 +122,9 @@ def forward(self,


        # Aggregate the node features back.
+        # print("src", node_features_broadcasted.shape)
+        # print("index", receivers.shape)
+        # print("dim", node_features.shape[0])
        node_features = scatter_mean(
            node_features_broadcasted,
            receivers,
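The "poor man's filter" above is a standard e3nn trick: tensoring with a constant 0e input while keeping only 0e outputs drops every non-scalar channel, mimicking the irreps_array.filter("0e") the comment references. A minimal sketch of the idea in isolation (not from this repo; the irreps are made up):

import torch
from e3nn import o3

irreps = o3.Irreps("4x0e + 2x1o")                  # made-up example irreps
filter_tp = o3.FullTensorProduct(irreps, o3.Irreps("0e"),
                                 filter_ir_out=[o3.Irrep("0e")])

x = irreps.randn(10, -1)                           # (10, 10): 4 scalars + 2 vectors per row
scalars = filter_tp(x, torch.ones(10, 1))          # constant 0e "dummy" input
print(scalars.shape)                               # torch.Size([10, 4]): only 0e survives

Only the 0e x 0e paths pass the filter, so the result is the scalar block of x (up to the tensor-product normalization).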
23 changes: 11 additions & 12 deletions train.py
@@ -17,17 +17,16 @@

model = torch.compile(model, fullgraph=True)

-# model(graph.numbers,
-#       graph.relative_vectors,
-#       graph.edge_index,
-#       graph.num_nodes)
-
-# Currently turning off since Linear still needs weights
-# Also need to confirm that the model is working
-model = torch.export.export(model,
-                            (graph.numbers,
-                             graph.relative_vectors,
-                             graph.edge_index,
-                             graph.num_nodes))
+model(graph.numbers,
+      graph.relative_vectors,
+      graph.edge_index,
+      graph.num_nodes)


+# model = torch.export.export(model,
+#                             (graph.numbers,
+#                              graph.relative_vectors,
+#                              graph.edge_index,
+#                              graph.num_nodes))
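For context, torch.export.export captures one full ahead-of-time graph, so anything data-dependent, such as the output size scatter_mean derives from index.max(), is a common sticking point and may be why the export call is parked here. A toy sketch of the export flow, assuming PyTorch 2.2+ (Double is a made-up stand-in, not this repo's model):

import torch

class Double(torch.nn.Module):
    def forward(self, x):
        return 2 * x

ep = torch.export.export(Double(), (torch.randn(4),))  # returns an ExportedProgram
print(ep)                                               # inspect the captured graph
y = ep.module()(torch.randn(4))                         # run the exported program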
48 changes: 25 additions & 23 deletions utils.py
@@ -2,53 +2,55 @@

import torch


def scatter_mean(input, index=None, dim=None):
    if index is not None:
        # Case 1: index is specified
        output_size = index.max().item() + 1
        output = torch.zeros(output_size, input.size(1), device=input.device)
        n = torch.zeros(output_size, device=input.device)

        for i in range(input.size(0)):
            idx = index[i]
            n[idx] += 1
            output[idx] += (input[i] - output[idx]) / n[idx]  # incremental (running) mean

        return output

    elif dim is not None:
        # Case 2: index is skipped, dim (the segment sizes) is specified
        output = torch.zeros(len(dim), input.size(1), device=input.device)

        start_idx = 0
        for i, seg_size in enumerate(dim):
            end_idx = start_idx + seg_size
            if seg_size > 0:
                segment_sum = input[start_idx:end_idx].sum(dim=0)
                output[i] = segment_sum / seg_size
            start_idx = end_idx

        return output

    else:
        raise ValueError("Either 'index' or 'dim' must be specified.")

-# Example usage for Case 1 (index specified):
-input1 = torch.randn(3000, 144)
-index1 = torch.randint(0, 1000, (3000,))
-output1 = scatter_mean(input1, index=index1)
-print("Output shape (Case 1):", output1.shape)
-
-# Example usage for Case 2 (index skipped, output_dim specified):
-input2 = torch.randn(3000, 144)
-output_dim = [3000]
-output2 = scatter_mean(input2, dim=output_dim)
-print("Output shape (Case 2):", output2.shape)
-
-# Example usage for Case 3 (both specified):
-input = torch.randn(3000, 144)
-index = torch.randint(0, 1000, (3000,))
-output_dim = [3000]
-
-output = scatter_mean(input, index, output_dim)
-print(output.size())  # Should print torch.Size([1000, 144])

+# # Example usage for Case 1 (index specified):
+# input1 = torch.randn(3000, 144)
+# index1 = torch.randint(0, 1000, (3000,))
+# output1 = scatter_mean(input1, index=index1)
+# print("Output shape (Case 1):", output1.shape)
+
+# # Example usage for Case 2 (index skipped, output_dim specified):
+# input2 = torch.randn(3000, 144)
+# output_dim = [3000]
+# output2 = scatter_mean(input2, dim=output_dim)
+# print("Output shape (Case 2):", output2.shape)
+
+# # Example usage for Case 3 (both specified):
+# input = torch.randn(3000, 144)
+# index = torch.randint(0, 1000, (3000,))
+# output_dim = [3000]
+
+# output = scatter_mean(input, index, output_dim)
+# print(output.size())  # Should print torch.Size([1000, 144])
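A possible way out of the scatter_mean issue (a sketch under stated assumptions, not this repo's implementation): replace the per-row Python loop with index_add_ and pass the number of output rows explicitly, so the output shape stops depending on index.max(). The names scatter_mean_vectorized and num_segments are illustrative.

import torch

def scatter_mean_vectorized(src, index, num_segments):
    # Sum rows of src into their segments, then divide by per-segment counts.
    out = torch.zeros(num_segments, src.size(1), device=src.device, dtype=src.dtype)
    count = torch.zeros(num_segments, device=src.device, dtype=src.dtype)
    out.index_add_(0, index, src)
    count.index_add_(0, index, torch.ones_like(index, dtype=src.dtype))
    return out / count.clamp(min=1).unsqueeze(-1)   # clamp avoids 0/0 for empty segments

src = torch.randn(3000, 144)
index = torch.randint(0, 1000, (3000,))
print(scatter_mean_vectorized(src, index, 1000).shape)  # torch.Size([1000, 144])

With no Python-level loop and no data-dependent output shape, this form should be friendlier to torch.compile(fullgraph=True) and torch.export.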
