diff --git a/data.py b/data.py index fb35964..05f5de8 100644 --- a/data.py +++ b/data.py @@ -3,13 +3,13 @@ from torch_geometric.loader import DataLoader from torch_geometric.nn import radius_graph -def get_random_graph(nodes, cutoff) -> Data: +def get_random_graph(n_nodes, cutoff) -> Data: - positions = torch.randn(nodes, 3) + positions = torch.randn(n_nodes, 3) distance_matrix = positions[:, None, :] - positions[None, :, :] distance_matrix = torch.linalg.norm(distance_matrix, dim=-1) - assert distance_matrix.shape == (nodes, nodes) + assert distance_matrix.shape == (n_nodes, n_nodes) senders, receivers = torch.nonzero(distance_matrix < cutoff).T @@ -20,11 +20,11 @@ def get_random_graph(nodes, cutoff) -> Data: # Create a PyTorch Geometric Data object graph = Data( - pos = positions, - relative_vectors = positions[receivers] - positions[senders], - y=z, - edge_index=edge_index, - num_nodes=len(positions) + pos = positions, # node positions + relative_vectors = positions[receivers] - positions[senders], # node relative positions + y=z, # graph label + edge_index=edge_index, # edge indices + numbers=torch.ones((len(positions),1)), # node features ) return graph diff --git a/inference.cpp b/inference.cpp index 34911c4..bb76d66 100644 --- a/inference.cpp +++ b/inference.cpp @@ -14,10 +14,21 @@ int main(int argc, char *argv[]) { runner = new torch::inductor::AOTIModelContainerRunnerCuda(model_path, 1); std::vector inputs = { // node_features - torch::randn({32,1}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)), + + torch::ones({32, 1}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)), + // pos - torch::randn({32,3}, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)), + torch::tensor({ + {0., 0., 0.}, {0., 0., 1.}, {1., 0., 0.}, {1., 1., 0.}, + {1., 1., 1.}, {1., 1., 2.}, {2., 1., 1.}, {2., 0., 1.}, + {0., 0., 0.}, {1., 0., 0.}, {0., 1., 0.}, {1., 1., 0.}, + {0., 0., 0.}, {0., 0., 1.}, {0., 0., 2.}, {0., 0., 3.}, + {0., 0., 0.}, {0., 0., 1.}, {0., 1., 0.}, {1., 0., 0.}, + {0., 0., 0.}, {0., 0., 1.}, {0., 0., 2.}, {0., 1., 0.}, + {0., 0., 0.}, {0., 0., 1.}, {0., 0., 2.}, {0., 1., 1.}, + {0., 0., 0.}, {1., 0., 0.}, {1., 1., 0.}, {2., 1., 0.} + }, torch::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA)), // edge_index torch::tensor({