Merge pull request #1 from SamirMoustafa/dev/initial-dev-branch
Quick fix for the number of operations counted for scatter_add
SamirMoustafa authored Mar 20, 2024
2 parents 55342cf + 19d4edd commit a05b958
Showing 4 changed files with 12 additions and 32 deletions.
18 changes: 6 additions & 12 deletions examples/message_passing.py
@@ -1,7 +1,7 @@
import inspect
from collections import OrderedDict

from torch import is_tensor
from torch import is_tensor, zeros
from torch.nn import Module

msg_special_args = {"edge_index", "edge_index_i", "edge_index_j", "size", "size_i", "size_j"}
@@ -77,27 +77,21 @@ def forward(self, edge_index, size=None, **kwargs):

def aggregate(self, inputs, index, dim_size):
num_features = inputs.shape[1]
x = zeros((dim_size, num_features), dtype=inputs.dtype)
x.scatter_reduce_(self.node_dim, index.view(-1, 1).expand(-1, num_features), inputs, self.aggr)
return x
index = index.view(-1, 1).expand(-1, num_features) if inputs.dim() > 1 else index
return zeros((dim_size, num_features), dtype=inputs.dtype).scatter_reduce_(self.node_dim, index, inputs, self.aggr)

def __repr__(self):
return "{}(dtype={})".format(self.__class__.__name__, self.dtype)

def message(self, x_j, edge_weight=None):
raise NotImplementedError("Not implemented yet.")
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j


if __name__ == "__main__":
from torch import ones, rand, long, zeros, device
from torch import ones, rand, long, device

from torch_operation_counter import OperationsCounterMode


class NativeMessagePassing(MessagePassing):
def message(self, x_j, edge_weight=None):
return x_j if edge_weight is None else edge_weight.view(-1, 1) * x_j

device = device("cpu")

# Reddit dataset settings
@@ -122,7 +116,7 @@ def message(self, x_j, edge_weight=None):
edge_index = ones((2, e), dtype=long).to(device)
edge_weight = rand(e).to(device)

conv = NativeMessagePassing()
conv = MessagePassing()

with OperationsCounterMode() as ops_counter:
conv(edge_index, size=(n, n), x=x, edge_weight=edge_weight)
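
A note on the reworked aggregate(): the output is now built with a single in-place scatter_reduce_ call over a zeros buffer, and the edge index is only broadcast across features when the messages are 2-D. Below is a minimal standalone sketch of the same pattern (not part of the diff; it assumes node_dim=0 and aggr="sum", which this hunk does not show):

```python
from torch import tensor, zeros

# Three edge messages (2 features each) aggregated onto 2 target nodes.
inputs = tensor([[1.0, 2.0],    # edge 0 -> node 0
                 [3.0, 4.0],    # edge 1 -> node 1
                 [5.0, 6.0]])   # edge 2 -> node 1
index = tensor([0, 1, 1])       # target node of each edge
dim_size, num_features = 2, inputs.shape[1]

# Broadcast the edge index over the feature dimension, then reduce in place.
index = index.view(-1, 1).expand(-1, num_features)
out = zeros((dim_size, num_features), dtype=inputs.dtype)
out.scatter_reduce_(0, index, inputs, "sum")
print(out)   # tensor([[1., 2.], [8., 10.]])
```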
10 changes: 3 additions & 7 deletions examples/pyg_gcn.py
@@ -13,14 +13,11 @@ def __init__(self, in_channels, hidden_channels, out_channels, num_layers):
self.convs = ModuleList()
for i in range(num_layers):
in_channel = in_channels if i == 0 else hidden_channels
out_channel = hidden_channels if i < num_layers - 1 else out_channels
self.convs.append(GCNConv(in_channel, out_channel, normalize=False))
self.fc = Linear(out_channels, out_channels)
self.convs.append(GCNConv(in_channel, hidden_channels, add_self_loops=False))

def forward(self, x, edge_index, edge_weight=None):
for i, conv in enumerate(self.convs):
x = conv(x, edge_index, edge_weight)
x = self.fc(x)
return x


@@ -46,7 +43,7 @@ def forward(self, x, edge_index, edge_weight=None):
# out_feature = 3

# model hyper-parameters
hidden_channels = 128
hidden_channels = 32

model = GCN(num_feature, hidden_channels, out_feature, 2).to(device)
x = randn(n, num_feature).to(device)
@@ -65,7 +62,6 @@ def forward(self, x, edge_index, edge_weight=None):

print()
# Print the total number of operations by two layers of GCN.
# The total number of operations should be similar to 19 GigaOPs as reported in the introduction section of the paper
# The total number of operations should be around 19 GigaOPs as reported in the introduction section of the paper
# Tailor, Shyam A. et al. “Degree-Quant: Quantization-Aware Training for Graph Neural Networks.”, ICLR 2021
print(f"{model.__class__.__name__}::Total operations: GigiaOP(s) {ops_counter.total_operations / 1e9}")
print(f"{model.__class__.__name__}::Total parameters: KiloParam(s) {sum([p.numel() for p in [*model.parameters()][:-1]]) / 1e3}")
2 changes: 1 addition & 1 deletion setup.py
@@ -1,6 +1,6 @@
from setuptools import find_packages, setup

__version__ = "0.3.1"
__version__ = "0.3.2"

install_requires = [
"torch>=1.13.1",
14 changes: 2 additions & 12 deletions torch_operation_counter/counters.py
@@ -17,20 +17,10 @@ def arange_ops(inputs: List[Any], outputs: List[Any]) -> Number:
return num_operations


def scatter_add_ops(inputs: List[Any], outputs: List[Any]) -> Number:
dim = inputs[1]
def scatter_add_ops(inputs: List[Any], outputs: List[Any]) -> int:
index = inputs[2]
src = inputs[3]
assert index.dim() == src.dim(), "index and src must have the same number of dimensions"
num_operations = 0
if dim < 0:
dim += src.dim()
for d in range(src.dim()):
if d == dim:
unique_indices, counts = index.select(dim, d).view(-1).unique(return_counts=True)
num_operations += counts.sum().item()
else:
num_operations += src[d].numel()
num_operations = src.numel() + index.numel()
return num_operations
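
The simplified counter now charges scatter_add with one operation per source element plus one per index element instead of walking the scattered dimension. A small illustrative sketch (shapes chosen arbitrarily) of the quantity scatter_add_ops returns for a typical call:

```python
from torch import randint, randn, zeros

out = zeros(3, 4)
index = randint(0, 3, (6, 4))      # target row for each source element
src = randn(6, 4)
out.scatter_add_(0, index, src)    # out[index[i][j]][j] += src[i][j]

# The revised rule: one add per src element plus one lookup per index element.
num_operations = src.numel() + index.numel()   # 24 + 24 = 48
print(num_operations)
```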


