Skip to content

Commit

Permalink
removed device=... from multi-dimensional indices, as they will be moved to the GPU anyway and will not be cached
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Dec 16, 2024
1 parent ec75123 commit 3c934f0
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions cirkit/backend/torch/layers/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,12 @@ def forward(self, x: Tensor) -> Tensor:
x = x.squeeze(dim=3) # (F, C, B)
weight = self.weight()
if self.num_channels == 1:
idx_fold = torch.arange(self.num_folds, device=weight.device)
idx_fold = torch.arange(self.num_folds)
x = weight[:, :, 0][idx_fold[:, None], :, x[:, 0]]
x = self.semiring.map_from(x, SumProductSemiring)
else:
idx_fold = torch.arange(self.num_folds, device=weight.device)[:, None, None]
idx_channel = torch.arange(self.num_channels, device=weight.device)[None, :, None]
idx_fold = torch.arange(self.num_folds)[:, None, None]
idx_channel = torch.arange(self.num_channels)[None, :, None]
x = weight[idx_fold, :, idx_channel, x]
x = self.semiring.map_from(x, SumProductSemiring)
x = self.semiring.prod(x, dim=1)
Expand Down Expand Up @@ -434,11 +434,11 @@ def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
# logits: (F, K, C, N)
logits = torch.log(self.probs()) if self.logits is None else self.logits()
if self.num_channels == 1:
idx_fold = torch.arange(self.num_folds, device=logits.device)
idx_fold = torch.arange(self.num_folds)
x = logits[:, :, 0][idx_fold[:, None], :, x[:, 0]]
else:
idx_fold = torch.arange(self.num_folds, device=logits.device)[:, None, None]
idx_channel = torch.arange(self.num_channels, device=logits.device)[None, :, None]
idx_fold = torch.arange(self.num_folds)[:, None, None]
idx_channel = torch.arange(self.num_channels)[None, :, None]
x = torch.sum(logits[idx_fold, :, idx_channel, x], dim=1)
return x

Expand Down

0 comments on commit 3c934f0

Please sign in to comment.