Handling non-batched tensors in input (TODO: revisit this)
JadenFiotto-Kaufman committed Feb 1, 2024
1 parent 53b1029 commit 5adc00c
Showing 1 changed file with 56 additions and 7 deletions.
63 changes: 56 additions & 7 deletions src/nnsight/intervention.py
@@ -164,10 +164,22 @@ def value(self) -> Any:
         return self.node.value
 
 
-def concat(activations: Any, value: Any, batch_start: int, batch_size: int):
+def concat(
+    activations: Any,
+    value: Any,
+    batch_start: int,
+    batch_size: int,
+    total_batch_size: int,
+):
     def _concat(values):
         if isinstance(values[0], torch.Tensor):
-            return torch.concatenate(values)
+            # The last element of `values` holds each tensor's original batch
+            # size (see orig_sizes below). Concatenate only if the pieces sum
+            # back to that size; otherwise this tensor was never narrowed, so
+            # return it as-is. TODO
+            orig_size = values[-1]
+            new_size = sum(value.shape[0] for value in values[:-1])
+            if new_size == orig_size:
+                return torch.concatenate(values[:-1])
+            return values[0]
         elif isinstance(values[0], list):
             return [
                 _concat([value[value_idx] for value in values])
@@ -190,17 +202,34 @@ def _concat(values):
     # As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
     # we need to concatenate the batches before and after the relevant batch with the new values.
     # Getting batch data before.
-    pre = util.apply(activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor)
+
+    def narrow1(acts: torch.Tensor):
+        if total_batch_size == acts.shape[0]:
+            return acts.narrow(0, 0, batch_start)
+
+        return acts
+
+    def narrow2(acts: torch.Tensor):
+        if total_batch_size == acts.shape[0]:
+            return acts.narrow(0, post_batch_start, acts.shape[0] - post_batch_start)
+
+        return acts
+
+    pre = util.apply(activations, lambda x: narrow1(x), torch.Tensor)
     post_batch_start = batch_start + batch_size
     # Getting batch data after.
     post = util.apply(
         activations,
-        lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
+        lambda x: narrow2(x),
         torch.Tensor,
     )
 
+    # Record each tensor's original batch size so _concat can tell whether it
+    # was actually narrowed (same rationale as total_batch_size). TODO
+    orig_sizes = util.apply(activations, lambda x: x.shape[0], torch.Tensor)
+
     # Concatenate
-    return _concat([pre, value, post])
+    return _concat([pre, value, post, orig_sizes])
 
 
 def intervene(activations: Any, module_path: str, graph: Graph, key: str):
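The pattern in both hunks is the same: every narrow/concatenate is guarded by total_batch_size, so tensors that do not carry the batch dimension (e.g. a shared buffer riding along in the module's input) pass through untouched. A minimal standalone sketch of that guard, assuming a single torch.Tensor rather than nnsight's nested structures (guarded_concat is a hypothetical name for illustration, not nnsight API):

import torch

def guarded_concat(activations, value, batch_start, batch_size, total_batch_size):
    # A tensor without the batch dimension: narrowing/concatenating would corrupt it.
    if activations.shape[0] != total_batch_size:
        return activations
    pre = activations.narrow(0, 0, batch_start)
    post_start = batch_start + batch_size
    post = activations.narrow(0, post_start, total_batch_size - post_start)
    pieces = [pre, value, post]
    # Concatenate only if the pieces sum back to the original batch size.
    if sum(p.shape[0] for p in pieces) == activations.shape[0]:
        return torch.cat(pieces)
    return activations

batched = torch.zeros(4, 3)    # first dim matches the total batch size of 4
unbatched = torch.zeros(2, 3)  # first dim is 2, unrelated to batching
patch = torch.ones(2, 3)       # replacement rows for the invoke at rows 1..2

print(guarded_concat(batched, patch, 1, 2, 4).sum())    # tensor(6.) - rows 1..2 swapped
print(guarded_concat(unbatched, patch, 1, 2, 4).sum())  # tensor(0.) - left untouched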
@@ -242,9 +271,27 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):

             # We set its result to the activations, indexed by only the relevant batch idxs.
 
+            # We find the max size of all shapes[0] and assume that is the total batch size.
+            # We then use this to NOT narrow tensors that do not have this size as their first dim.
+            # TODO maybe this isn't the right way to handle this. Maybe just check if multiple invokes happen and, if not, don't narrow.
+            total_batch_size = None
+
+            def narrow(acts: torch.Tensor):
+                nonlocal total_batch_size
+
+                _batch_size = acts.shape[0]
+
+                if total_batch_size is None or _batch_size > total_batch_size:
+                    total_batch_size = _batch_size
+
+                if total_batch_size == _batch_size:
+                    return acts.narrow(0, batch_start, batch_size)
+
+                return acts
+
             value = util.apply(
                 activations,
-                lambda x: x.narrow(0, batch_start, batch_size),
+                lambda x: narrow(x),
                 torch.Tensor,
             )
 
@@ -254,7 +301,9 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
             # This would mean we want to replace activations for this batch with some other ones.
             value = graph.get_swap(value)
 
-            activations = concat(activations, value, batch_start, batch_size)
+            activations = concat(
+                activations, value, batch_start, batch_size, total_batch_size
+            )
 
     return activations
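Because intervene only sees whatever arrives in activations, the total batch size is never passed in; it is inferred as the running maximum of every first dimension seen, and only tensors matching it are narrowed to the current invoke's rows. A minimal sketch of that heuristic over a flat list of tensors (select_batch is a hypothetical helper for illustration, not nnsight API):

import torch

def select_batch(tensors, batch_start, batch_size):
    total_batch_size = None
    narrowed = []
    for t in tensors:
        # Assume the largest first dimension seen so far is the total batch size.
        if total_batch_size is None or t.shape[0] > total_batch_size:
            total_batch_size = t.shape[0]
        # Narrow only tensors that carry the full batch.
        if t.shape[0] == total_batch_size:
            narrowed.append(t.narrow(0, batch_start, batch_size))
        else:
            narrowed.append(t)
    return narrowed

hidden = torch.randn(4, 10)  # batched activation: total batch size 4
buffer = torch.randn(1, 10)  # non-batched tensor in the same input
shapes = [t.shape for t in select_batch([hidden, buffer], batch_start=2, batch_size=2)]
print(shapes)  # [torch.Size([2, 10]), torch.Size([1, 10])]

As the diff's own TODO concedes, this running-maximum guess is order-sensitive: a non-batched tensor whose first dimension happens to be the largest seen so far would still be narrowed, which is presumably why the commit title says "revisit this".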
