diff --git a/src/nnsight/intervention.py b/src/nnsight/intervention.py
index 2ce6b0c1..4ca5d7c9 100644
--- a/src/nnsight/intervention.py
+++ b/src/nnsight/intervention.py
@@ -164,10 +164,22 @@ def value(self) -> Any:
         return self.node.value
 
 
-def concat(activations: Any, value: Any, batch_start: int, batch_size: int):
+def concat(
+    activations: Any,
+    value: Any,
+    batch_start: int,
+    batch_size: int,
+    total_batch_size: int,
+):
     def _concat(values):
         if isinstance(values[0], torch.Tensor):
-            return torch.concatenate(values)
+            # Only concatenate if the pieces cover the original batch size;
+            # otherwise the tensor was never narrowed. TODO: same caveat as total_batch_size.
+            orig_size = values[-1]
+            new_size = sum([value.shape[0] for value in values[:-1]])
+            if new_size == orig_size:
+                return torch.concatenate(values[:-1])
+            return values[0]
         elif isinstance(values[0], list):
             return [
                 _concat([value[value_idx] for value in values])
@@ -190,17 +202,34 @@ def _concat(values):
     # As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
     # we need to concatenate the batches before and after the relevant batch with the new values.
     # Getting batch data before.
-    pre = util.apply(activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor)
+
+    def narrow1(acts: torch.Tensor):
+        if total_batch_size == acts.shape[0]:
+            return acts.narrow(0, 0, batch_start)
+
+        return acts
+
+    def narrow2(acts: torch.Tensor):
+        if total_batch_size == acts.shape[0]:
+            return acts.narrow(0, post_batch_start, acts.shape[0] - post_batch_start)
+
+        return acts
+
+    pre = util.apply(activations, lambda x: narrow1(x), torch.Tensor)
 
     post_batch_start = batch_start + batch_size
     # Getting batch data after.
     post = util.apply(
         activations,
-        lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
+        lambda x: narrow2(x),
         torch.Tensor,
     )
+    # Record each tensor's original leading dim so _concat can tell whether
+    # it was narrowed. TODO: same caveat as total_batch_size.
+    orig_sizes = util.apply(activations, lambda x: x.shape[0], torch.Tensor)
+
     # Concatenate
-    return _concat([pre, value, post])
+    return _concat([pre, value, post, orig_sizes])
 
 
 def intervene(activations: Any, module_path: str, graph: Graph, key: str):
@@ -242,9 +271,27 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
 
             # We set its result to the activations, indexed by only the relevant batch idxs.
+            # We find the max of all shapes[0] and assume that is the total batch size.
+            # We then use this to NOT narrow tensors that do not have this size as their first dim.
+            # TODO: maybe this isn't the right way to handle this. Maybe just check whether multiple invokes happened and, if not, don't narrow.
+            total_batch_size = None
+
+            def narrow(acts: torch.Tensor):
+                nonlocal total_batch_size
+
+                _batch_size = acts.shape[0]
+
+                if total_batch_size is None or _batch_size > total_batch_size:
+                    total_batch_size = _batch_size
+
+                if total_batch_size == _batch_size:
+                    return acts.narrow(0, batch_start, batch_size)
+
+                return acts
+
             value = util.apply(
                 activations,
-                lambda x: x.narrow(0, batch_start, batch_size),
+                lambda x: narrow(x),
                 torch.Tensor,
             )
@@ -254,7 +301,9 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
 
             # This would mean we want to replace activations for this batch with some other ones.
             value = graph.get_swap(value)
 
-            activations = concat(activations, value, batch_start, batch_size)
+            activations = concat(
+                activations, value, batch_start, batch_size, total_batch_size
+            )
 
     return activations
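
For context on the `narrow` guard added in `intervene`: during multi-invoke or generation runs, not every tensor flowing through a module spans the full batch (one may already be collapsed to a leading dim of 1), so slicing it with `batch_start`/`batch_size` would be wrong. Below is a minimal standalone sketch of the heuristic; `make_narrow` and the toy tensors are illustrative, not nnsight API:

```python
import torch

def make_narrow(batch_start: int, batch_size: int):
    # The largest leading dim seen so far is assumed to be the total
    # batch size; tensors with a smaller leading dim pass through.
    total_batch_size = None

    def narrow(acts: torch.Tensor) -> torch.Tensor:
        nonlocal total_batch_size
        if total_batch_size is None or acts.shape[0] > total_batch_size:
            total_batch_size = acts.shape[0]
        if acts.shape[0] == total_batch_size:
            return acts.narrow(0, batch_start, batch_size)
        return acts

    return narrow

narrow = make_narrow(batch_start=1, batch_size=2)
print(narrow(torch.zeros(4, 8)).shape)  # torch.Size([2, 8]): full batch, sliced
print(narrow(torch.zeros(1, 8)).shape)  # torch.Size([1, 8]): smaller, untouched
```

Note that the first tensor seen always sets `total_batch_size` and is narrowed unconditionally, so the heuristic relies on a full-batch tensor arriving before any smaller one, which is part of what the TODO in the patch acknowledges.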
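On the `concat` side, the patch appends each tensor's original leading dim (`orig_sizes`) as a sentinel so `_concat` can tell whether the pre/value/post pieces actually came from narrowing. A self-contained sketch of that check; `reassemble` is a hypothetical stand-in for the tensor branch of `_concat`:

```python
import torch

def reassemble(pre, value, post, orig_size: int):
    # Stitch the pieces back together only if their leading dims add up
    # to the original one, i.e. the tensor really was narrowed.
    if pre.shape[0] + value.shape[0] + post.shape[0] == orig_size:
        return torch.concatenate([pre, value, post])
    # Never narrowed: pre still holds the full, untouched tensor.
    return pre

x = torch.arange(6).float()
swapped = torch.tensor([9.0, 9.0])
print(reassemble(x[:1], swapped, x[3:], orig_size=6))
# tensor([0., 9., 9., 3., 4., 5.])

small = torch.ones(1, 8)  # skipped by the narrow guard above
print(reassemble(small, small, small, orig_size=1).shape)
# torch.Size([1, 8]): sizes don't add up, original returned as-is
```

This is why the sentinel is needed at all: when the narrow guard skipped a tensor, `pre` and `post` both still hold the full tensor, so blindly concatenating them would duplicate the batch.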