Skip to content

Commit

Permalink
made book keeping fold indices buffers
Browse files Browse the repository at this point in the history
  • Loading branch information
loreloc committed Dec 18, 2024
1 parent f979fc7 commit 16b2761
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 27 deletions.
22 changes: 12 additions & 10 deletions cirkit/backend/torch/circuits.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,24 @@ def lookup(
self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
) -> Iterator[tuple[TorchLayer | None, tuple]]:
# Loop through the entries and yield inputs
for entry in self._entries:
for entry in self:
layer = entry.module
in_layer_ids = entry.in_module_ids
in_fold_idx = entry.in_fold_idx
# Catch the case there are some inputs coming from other modules
if entry.in_module_ids:
(in_fold_idx,) = entry.in_fold_idx
(in_module_ids,) = entry.in_module_ids
if len(in_module_ids) == 1:
x = module_outputs[in_module_ids[0]]
if in_layer_ids:
in_fold_idx_h = in_fold_idx[0]
in_layer_ids_h = in_layer_ids[0]
if len(in_layer_ids_h) == 1:
x = module_outputs[in_layer_ids_h[0]]
else:
x = torch.cat([module_outputs[mid] for mid in in_module_ids], dim=0)
x = x[in_fold_idx]
yield entry.module, (x,)
x = torch.cat([module_outputs[mid] for mid in in_layer_ids_h], dim=0)
x = x[in_fold_idx_h]
yield layer, (x,)
continue

# Catch the case there are no inputs coming from other modules
# That is, we are gathering the inputs of input layers
layer = entry.module
assert isinstance(layer, TorchInputLayer)
if layer.num_variables:
if in_graph is None:
Expand Down
36 changes: 32 additions & 4 deletions cirkit/backend/torch/graph/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
from typing import Any, Protocol, TypeVar

import torch
from torch import Tensor, nn

from cirkit.utils.algorithms import DiAcyclicGraph, subgraph
Expand Down Expand Up @@ -136,24 +137,51 @@ def __init__(self, entries: list[AddressBookEntry]) -> None:
if len(out_fold_idx.shape) != 1:
raise ValueError("The output fold index tensor should be a 1-dimensional tensor")
super().__init__()
self._entries = entries
self._num_outputs = out_fold_idx.shape[0]
self._entry_modules: list[TorchModule | None] = [e.module for e in entries]
self._entry_in_module_ids: list[list[list[int]]] = [e.in_module_ids for e in entries]
# We register the book-keeping tensor indices as buffers.
# By doing so they are automatically transferred to the device
# This reduces CPU-device communications required to transfer these indices
#
# TODO: Perhaps this can be made more elegant in the future, if someone
# decides to introduce a nn.BufferList container in torch
self._entry_in_fold_idx_targets: list[list[str]] = []
for i, e in enumerate(entries):
self._entry_in_fold_idx_targets.append([])
for j, fi in enumerate(e.in_fold_idx):
in_fold_idx_target = f"_in_fold_idx_{i}_{j}"
self.register_buffer(in_fold_idx_target, fi)
self._entry_in_fold_idx_targets[-1].append(in_fold_idx_target)

def __len__(self) -> int:
"""Retrieve the length of the address book.
Returns:
The number of address book entries.
"""
return len(self._entries)
return len(self._entry_modules)

def __iter__(self) -> Iterator[AddressBookEntry]:
"""Retrieve an iterator over address book entries.
"""Retrieve an iterator over address book entries, i.e., a tuple consisting of
three objects: (i) the torch module to evaluate (it can be None if the entry
is needed to return the output of the computational graph); (ii) for each input
to the module (i.e., depending on the arity) we have the list of ids to the
outputs of other modules (it can be empty if the module is an input module); and
(iii) for each input to the module we have the fold indexing tensor, which
is used to retrieve the inputs to a module, even if they are folded modules.
Returns:
An iterator over address book entries.
"""
return iter(self._entries)
for module, in_module_ids_hs, in_fold_idx_targets in zip(
self._entry_modules, self._entry_in_module_ids, self._entry_in_fold_idx_targets
):
yield AddressBookEntry(
module,
in_module_ids_hs,
[self.get_buffer(target) for target in in_fold_idx_targets],
)

@property
def num_outputs(self) -> int:
Expand Down
12 changes: 6 additions & 6 deletions cirkit/backend/torch/layers/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,12 +281,12 @@ def forward(self, x: Tensor) -> Tensor:
x = x.squeeze(dim=3) # (F, C, B)
weight = self.weight()
if self.num_channels == 1:
idx_fold = torch.arange(self.num_folds)
idx_fold = torch.arange(self.num_folds, device=weight.device)
x = weight[:, :, 0][idx_fold[:, None], :, x[:, 0]]
x = self.semiring.map_from(x, SumProductSemiring)
else:
idx_fold = torch.arange(self.num_folds)[:, None, None]
idx_channel = torch.arange(self.num_channels)[None, :, None]
idx_fold = torch.arange(self.num_folds, device=weight.device)[:, None, None]
idx_channel = torch.arange(self.num_channels, device=weight.device)[None, :, None]
x = weight[idx_fold, :, idx_channel, x]
x = self.semiring.map_from(x, SumProductSemiring)
x = self.semiring.prod(x, dim=1)
Expand Down Expand Up @@ -434,11 +434,11 @@ def log_unnormalized_likelihood(self, x: Tensor) -> Tensor:
# logits: (F, K, C, N)
logits = torch.log(self.probs()) if self.logits is None else self.logits()
if self.num_channels == 1:
idx_fold = torch.arange(self.num_folds)
idx_fold = torch.arange(self.num_folds, device=logits.device)
x = logits[:, :, 0][idx_fold[:, None], :, x[:, 0]]
else:
idx_fold = torch.arange(self.num_folds)[:, None, None]
idx_channel = torch.arange(self.num_channels)[None, :, None]
idx_fold = torch.arange(self.num_folds, device=logits.device)[:, None, None]
idx_channel = torch.arange(self.num_channels, device=logits.device)[None, :, None]
x = torch.sum(logits[idx_fold, :, idx_channel, x], dim=1)
return x

Expand Down
15 changes: 8 additions & 7 deletions cirkit/backend/torch/parameters/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,20 +30,21 @@ def lookup(
self, module_outputs: list[Tensor], *, in_graph: Tensor | None = None
) -> Iterator[tuple[TorchParameterNode | None, tuple]]:
# Loop through the entries and yield inputs
for entry in self._entries:
in_module_ids = entry.in_module_ids

for entry in self:
node = entry.module
in_node_ids = entry.in_module_ids
in_fold_idx = entry.in_fold_idx
# Catch the case there are some inputs coming from other modules
if in_module_ids:
if in_node_ids:
x = tuple(
ParameterAddressBook._select_index(module_outputs, mids, in_idx)
for mids, in_idx in zip(in_module_ids, entry.in_fold_idx)
for mids, in_idx in zip(in_node_ids, in_fold_idx)
)
yield entry.module, x
yield node, x
continue

# Catch the case there are no inputs coming from other modules
yield entry.module, ()
yield node, ()

@staticmethod
def _select_index(node_outputs: list[Tensor], mids: list[int], idx: Tensor | None) -> Tensor:
Expand Down

0 comments on commit 16b2761

Please sign in to comment.