CAN Attention: standard management of heads and out_channels #276

Merged (1 commit) on May 14, 2024
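In brief: before this change `out_channels` was a per-head width, so concatenating heads produced `heads * out_channels` features; after it, `out_channels` is the total output width and each head works on a slice of `out_channels // heads`. A minimal sketch of the two conventions (values illustrative):

```python
heads, out_channels = 2, 64

# Old convention: out_channels was per head, so concat widened the output.
old_concat_width = heads * out_channels    # 128

# New convention: out_channels is the total width; each head gets a slice.
per_head_width = out_channels // heads     # 32
new_concat_width = heads * per_head_width  # 64 == out_channels
```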
6 changes: 3 additions & 3 deletions test/nn/cell/test_can.py
@@ -17,7 +17,7 @@ def test_forward(self):
             in_channels_1=2,
             out_channels=2,
             dropout=0.5,
-            heads=1,
+            heads=2,
             n_layers=2,
             att_lift=False,
             pooling=True,
@@ -31,8 +31,8 @@ def test_forward(self):
         ).to_sparse()

         x_0, x_1 = (
-            torch.tensor(x_0).float().to(device),
-            torch.tensor(x_1).float().to(device),
+            x_0.clone().detach().float().to(device),
+            x_1.clone().detach().float().to(device),
         )
         adjacency_1 = adjacency_1.float().to(device)
         adjacency_2 = adjacency_1.float().to(device)
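Aside from the head handling, the test now builds its tensors with `clone().detach()` instead of `torch.tensor(...)`, which is the idiom PyTorch itself recommends when copy-constructing from an existing tensor. A small sketch of the difference:

```python
import torch

x_0 = torch.randn(4, 2)

# torch.tensor(x_0) works but emits a UserWarning recommending
# sourceTensor.clone().detach() when the source is already a tensor.
x_old = torch.tensor(x_0).float()

# The pattern adopted here: copy the data and cut any autograd history
# before casting and moving devices.
x_new = x_0.clone().detach().float()
```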
13 changes: 7 additions & 6 deletions test/nn/cell/test_can_layer.py
@@ -13,7 +13,7 @@ class TestCANLayer:
     def test_forward(self):
         """Test the forward method of CANLayer."""
         in_channels = 7
-        out_channels = 64
+        out_channels = 66
         dropout_values = [0.5, 0.7]
         heads_values = [1, 3]
         concat_values = [True, False]
@@ -65,12 +65,12 @@ def test_forward(self):
             )
             x_out = can_layer.forward(x_1, lower_neighborhood, upper_neighborhood)
             if concat:
-                assert x_out.shape == (n_cells, out_channels * heads)
-            else:
                 assert x_out.shape == (n_cells, out_channels)
+            else:
+                assert x_out.shape == (n_cells, out_channels // heads)

         # Test if there are no non-zero values in the neighborhood
-        heads = 1
+        heads = 3
         concat_list = [True, False]
         skip_connection = True
@@ -79,6 +79,7 @@ def test_forward(self):
             can_layer = CANLayer(
                 in_channels=in_channels,
                 out_channels=out_channels,
+                heads=heads,
                 concat=concat,
                 skip_connection=skip_connection,
                 version=version,
@@ -89,9 +90,9 @@ def test_forward(self):
                 torch.zeros_like(upper_neighborhood),
             )
             if concat:
-                assert x_out.shape == (n_cells, out_channels * heads)
-            else:
                 assert x_out.shape == (n_cells, out_channels)
+            else:
+                assert x_out.shape == (n_cells, out_channels // heads)

     def test_reset_parameters(self):
         """Test the reset_parameters method of CANLayer."""
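Read together, the assertions pin down the new contract (note `out_channels` also moves from 64 to 66 so it divides evenly by 3 heads). A standalone sketch of the same checks, assuming the constructor defaults behave as in this test and using an identity neighborhood:

```python
import torch
from topomodelx.nn.cell.can_layer import CANLayer

n_cells, in_channels, out_channels, heads = 10, 7, 66, 3
x_1 = torch.randn(n_cells, in_channels)
neighborhood = torch.eye(n_cells).to_sparse()

# concat=True: heads are concatenated, output keeps the full out_channels.
layer = CANLayer(
    in_channels=in_channels, out_channels=out_channels, heads=heads, concat=True
)
assert layer.forward(x_1, neighborhood, neighborhood).shape == (n_cells, out_channels)

# concat=False: heads are averaged, output shrinks to out_channels // heads.
layer = CANLayer(
    in_channels=in_channels, out_channels=out_channels, heads=heads, concat=False
)
assert layer.forward(x_1, neighborhood, neighborhood).shape == (
    n_cells,
    out_channels // heads,
)
```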
4 changes: 2 additions & 2 deletions topomodelx/nn/cell/can.py
@@ -94,7 +94,7 @@ def __init__(
         for _ in range(n_layers - 1):
             layers.append(
                 CANLayer(
-                    in_channels=out_channels * heads,
+                    in_channels=out_channels,
                     out_channels=out_channels,
                     dropout=dropout,
                     heads=heads,
@@ -110,7 +110,7 @@ def __init__(
             layers.append(
                 PoolLayer(
                     k_pool=k_pool,
-                    in_channels_0=out_channels * heads,
+                    in_channels_0=out_channels,
                     signal_pool_activation=torch.nn.Sigmoid(),
                     readout=True,
                     **kwargs,
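This is the payoff in the stacked model: every `CANLayer` now emits `out_channels` features regardless of `heads`, so later layers and the pooling layer chain without the `* heads` bookkeeping. A quick width trace under both conventions:

```python
heads, out_channels = 2, 32

# Old convention: the hidden width between stacked layers (and the
# pooling input) grew with the number of heads.
old_hidden_width = out_channels * heads  # 64

# New convention: the hidden width is just out_channels, whatever heads is.
new_hidden_width = out_channels          # 32
```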
49 changes: 23 additions & 26 deletions topomodelx/nn/cell/can_layer.py
@@ -439,9 +439,9 @@ def __init__(
         self.dropout = dropout
         self.add_self_loops = add_self_loops

-        self.lin = torch.nn.Linear(in_channels, heads * out_channels, bias=False)
-        self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels))
-        self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels))
+        self.lin = torch.nn.Linear(in_channels, out_channels, bias=False)
+        self.att_weight_src = Parameter(torch.Tensor(1, heads, out_channels // heads))
+        self.att_weight_dst = Parameter(torch.Tensor(1, heads, out_channels // heads))

         self.reset_parameters()
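The projection now keeps the total width, and the per-head attention vectors score slices of size `out_channels // heads`. A standalone sketch of the shapes involved (using random init here; the layer itself initializes via `reset_parameters`):

```python
import torch
from torch.nn import Parameter

in_channels, out_channels, heads = 7, 66, 3

lin = torch.nn.Linear(in_channels, out_channels, bias=False)
att_weight_src = Parameter(torch.randn(1, heads, out_channels // heads))

x = torch.randn(10, in_channels)
# One projection, then a reshape into per-head slices, as in message() below.
x_per_head = lin(x).view(-1, heads, out_channels // heads)  # (10, 3, 22)

# Broadcasting the (1, heads, C // heads) attention vector against the
# messages yields one score per cell per head.
alpha_src = (x_per_head * att_weight_src).sum(dim=-1)       # (10, 3)
```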

@@ -468,7 +468,7 @@ def message(self, x_source):
         """
         # Compute the linear transformation on the source features
         x_message = self.lin(x_source).view(
-            -1, self.heads, self.out_channels
+            -1, self.heads, self.out_channels // self.heads
         )  # (n_k_cells, H, C)

         # compute the source and target messages
@@ -534,12 +534,13 @@ def forward(self, x_source, neighborhood):
         # If there are no non-zero values in the neighborhood, then the neighborhood is empty. -> return zero tensor
         if not neighborhood.values().nonzero().size(0) > 0 and self.concat:
             return torch.zeros(
-                (x_source.shape[0], self.out_channels * self.heads),
+                (x_source.shape[0], self.out_channels),
                 device=x_source.device,
             )  # (n_k_cells, H * C)
         if not neighborhood.values().nonzero().size(0) > 0 and not self.concat:
             return torch.zeros(
-                (x_source.shape[0], self.out_channels), device=x_source.device
+                (x_source.shape[0], self.out_channels // self.heads),
+                device=x_source.device,
             )  # (n_k_cells, C)

         # Add self-loops to the neighborhood matrix if necessary
@@ -559,9 +560,7 @@

         # if concat true, concatenate the messages for each head. Otherwise, average the messages for each head.
         if self.concat:
-            return aggregated_message.view(
-                -1, self.heads * self.out_channels
-            )  # (n_k_cells, H * C)
+            return aggregated_message.view(-1, self.out_channels)  # (n_k_cells, H * C)

         return aggregated_message.mean(dim=1)  # (n_k_cells, C)
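After aggregation the message tensor has shape `(n_k_cells, heads, out_channels // heads)`, so both branches now produce widths derived from `out_channels`. A toy illustration:

```python
import torch

n_cells, heads, out_channels = 10, 3, 66
aggregated = torch.randn(n_cells, heads, out_channels // heads)

concat_out = aggregated.view(-1, out_channels)  # (10, 66): heads side by side
mean_out = aggregated.mean(dim=1)               # (10, 22): heads averaged
```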

@@ -613,7 +612,7 @@ def __init__(
         heads: int,
         concat: bool,
         att_activation: torch.nn.Module,
-        add_self_loops: bool = False,
+        add_self_loops: bool = True,
         aggr_func: Literal["sum", "mean", "add"] = "sum",
         initialization: Literal["xavier_uniform", "xavier_normal"] = "xavier_uniform",
         share_weights: bool = False,
@@ -634,17 +633,13 @@

         if share_weights:
             self.lin_src = self.lin_dst = torch.nn.Linear(
-                in_channels, heads * out_channels, bias=False
+                in_channels, out_channels, bias=False
             )
         else:
-            self.lin_src = torch.nn.Linear(
-                in_channels, heads * out_channels, bias=False
-            )
-            self.lin_dst = torch.nn.Linear(
-                in_channels, heads * out_channels, bias=False
-            )
+            self.lin_src = torch.nn.Linear(in_channels, out_channels, bias=False)
+            self.lin_dst = torch.nn.Linear(in_channels, out_channels, bias=False)

-        self.att_weight = Parameter(torch.Tensor(1, heads, out_channels))
+        self.att_weight = Parameter(torch.Tensor(1, heads, out_channels // heads))

         self.reset_parameters()
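The v2 attention gets the same treatment: with `share_weights=True` the source and destination projections are one module, and either way each maps `in_channels` to the total `out_channels`. A sketch of the construction:

```python
import torch

in_channels, out_channels, heads, share_weights = 7, 66, 3, True

if share_weights:
    # One Linear serves both roles (the v2 layer's weight-sharing option),
    # halving the projection parameters.
    lin_src = lin_dst = torch.nn.Linear(in_channels, out_channels, bias=False)
else:
    lin_src = torch.nn.Linear(in_channels, out_channels, bias=False)
    lin_dst = torch.nn.Linear(in_channels, out_channels, bias=False)

# One attention vector per head, each over a per-head slice of the output.
att_weight = torch.nn.Parameter(torch.randn(1, heads, out_channels // heads))
```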

@@ -671,12 +666,12 @@ def message(self, x_source):
         """
         # Compute the linear transformation on the source features
         x_src_message = self.lin_src(x_source).view(
-            -1, self.heads, self.out_channels
+            -1, self.heads, self.out_channels // self.heads
         )  # (n_k_cells, H, C)

         # Compute the linear transformation on the source features
         x_dst_message = self.lin_dst(x_source).view(
-            -1, self.heads, self.out_channels
+            -1, self.heads, self.out_channels // self.heads
         )  # (n_k_cells, H, C)

         # Get the source and target projections of the neighborhood
@@ -737,12 +732,13 @@ def forward(self, x_source, neighborhood):
         # If there are no non-zero values in the neighborhood, then the neighborhood is empty. -> return zero tensor
         if not neighborhood.values().nonzero().size(0) > 0 and self.concat:
             return torch.zeros(
-                (x_source.shape[0], self.out_channels * self.heads),
+                (x_source.shape[0], self.out_channels),
                 device=x_source.device,
             )  # (n_k_cells, H * C)
         if not neighborhood.values().nonzero().size(0) > 0 and not self.concat:
             return torch.zeros(
-                (x_source.shape[0], self.out_channels), device=x_source.device
+                (x_source.shape[0], self.out_channels // self.heads),
+                device=x_source.device,
             )  # (n_k_cells, C)

         # Add self-loops to the neighborhood matrix if necessary
@@ -762,9 +758,7 @@

         # if concat true, concatenate the messages for each head. Otherwise, average the messages for each head.
         if self.concat:
-            return aggregated_message.view(
-                -1, self.heads * self.out_channels
-            )  # (n_k_cells, H * C)
+            return aggregated_message.view(-1, self.out_channels)  # (n_k_cells, H * C)

         return aggregated_message.mean(dim=1)  # (n_k_cells, C)

@@ -836,6 +830,9 @@ def __init__(
         assert in_channels > 0, ValueError("Number of input channels must be > 0")
         assert out_channels > 0, ValueError("Number of output channels must be > 0")
         assert heads > 0, ValueError("Number of heads must be > 0")
+        assert out_channels % heads == 0, ValueError(
+            "Number of output channels must be divisible by the number of heads"
+        )
         assert dropout >= 0.0 and dropout <= 1.0, ValueError("Dropout must be in [0,1]")

         # assert that shared weight is True only if version is v2
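The new guard catches incompatible configurations at construction time rather than failing later inside a reshape, and it explains the 64 to 66 change in the tests above:

```python
out_channels, heads = 66, 3
assert out_channels % heads == 0  # ok: each head gets 66 // 3 == 22 channels

out_channels, heads = 64, 3
# 64 % 3 != 0 -> the constructor now raises up front instead of
# mis-shaping the (1, heads, out_channels // heads) attention weights.
```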
@@ -893,7 +890,7 @@ def __init__(

         # linear transformation
         if skip_connection:
-            out_channels = out_channels * heads if concat else out_channels
+            out_channels = out_channels if concat else out_channels // heads
             self.lin = Linear(in_channels, out_channels, bias=False)
             self.eps = 1 + 1e-6
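Finally, the skip connection has to match whatever the attention branch emits, so its target width follows the same rule. Schematically, with names as in the hunk above:

```python
import torch

in_channels, out_channels, heads, concat = 7, 66, 3, False

# Width of the attention output this skip connection is added to:
# full out_channels when heads are concatenated, a single per-head
# slice when they are averaged.
skip_width = out_channels if concat else out_channels // heads  # 22 here
lin = torch.nn.Linear(in_channels, skip_width, bias=False)
eps = 1 + 1e-6  # residual scaling constant used by the layer
```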