Merge pull request #69 from JadenFiotto-Kaufman/dev
Dev
JadenFiotto-Kaufman authored Feb 1, 2024
2 parents 98979b3 + 9fdc376 commit 6ddbd1a
Showing 3 changed files with 219 additions and 19 deletions.
15 changes: 3 additions & 12 deletions src/nnsight/contexts/Invoker.py
@@ -90,18 +90,9 @@ def next(self, increment: int = 1) -> None:
# .next() increments the generation idx at which the interventions happen.
self.tracer.generation_idx += increment

if self.scan:
# Run graph with single token input.
self.inputs = self.tracer.model._prepare_inputs(
self.tracer.model._example_input(), *self.args, **self.kwargs
)
self.tracer.model._scan(
self.inputs, *self.tracer.args, **self.tracer.kwargs
)
else:
for name, module in self.tracer.model.meta_model.named_modules():
if isinstance(module, Module):
module.clear()
for name, module in self.tracer.model.meta_model.named_modules():
if isinstance(module, Module):
module.clear()

def save_all(self) -> Dict[str, Proxy]:
"""Saves the output of all modules and returns a dictionary of [module_path -> save proxy]
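The `.next()` change above drops the per-step scan and simply clears module proxies before moving to the next generation step. A rough usage sketch of where `.next()` fits, assuming the nnsight generate/invoke context API at this commit (model name and module path are illustrative):

# Sketch only: assumes nnsight's generate/invoke API at this commit.
from nnsight import LanguageModel

model = LanguageModel("gpt2")  # illustrative model

with model.generate(max_new_tokens=2) as generator:
    with generator.invoke("The capital of France is") as invoker:
        step0 = model.transformer.h[5].output.save()  # interventions for generation step 0
        invoker.next()                                # advance generation_idx; module proxies are cleared
        step1 = model.transformer.h[5].output.save()  # interventions for generation step 1

# After the context exits, .value holds the saved activations for each step.
print(step0.value[0].shape, step1.value[0].shape)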
63 changes: 56 additions & 7 deletions src/nnsight/intervention.py
@@ -164,10 +164,22 @@ def value(self) -> Any:
return self.node.value


def concat(activations: Any, value: Any, batch_start: int, batch_size: int):
def concat(
activations: Any,
value: Any,
batch_start: int,
batch_size: int,
total_batch_size: int,
):
def _concat(values):
if isinstance(values[0], torch.Tensor):
return torch.concatenate(values)
# For the same reason we pass total_batch_size: the last element holds this
# activation's original first-dim size, so we only re-concatenate when the
# narrowed pieces actually cover the full batch. TODO: revisit this approach.
orig_size = values[-1]
new_size = sum([value.shape[0] for value in values[:-1]])
if new_size == orig_size:
return torch.concatenate(values[:-1])
return values[0]
elif isinstance(values[0], list):
return [
_concat([value[value_idx] for value in values])
@@ -190,17 +202,34 @@ def _concat(values):
# As interventions are scoped only to their relevant batch, if we want to swap in values for this batch
# we need to concatenate the batches before and after the relevant batch with the new values.
# Getting batch data before.
pre = util.apply(activations, lambda x: x.narrow(0, 0, batch_start), torch.Tensor)

def narrow1(acts: torch.Tensor):
if total_batch_size == acts.shape[0]:
return acts.narrow(0, 0, batch_start)

return acts

def narrow2(acts: torch.Tensor):
if total_batch_size == acts.shape[0]:
return acts.narrow(0, post_batch_start, acts.shape[0] - post_batch_start)

return acts

pre = util.apply(activations, lambda x: narrow1(x), torch.Tensor)
post_batch_start = batch_start + batch_size
# Getting batch data after.
post = util.apply(
activations,
lambda x: x.narrow(0, post_batch_start, x.shape[0] - post_batch_start),
lambda x: narrow2(x),
torch.Tensor,
)

# Record each activation tensor's original first-dim size so _concat can tell
# whether the narrowed pieces cover the full batch (same reason we pass
# total_batch_size). TODO: revisit this approach.
orig_sizes = util.apply(activations, lambda x: x.shape[0], torch.Tensor)

# Concatenate
return _concat([pre, value, post])
return _concat([pre, value, post, orig_sizes])
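With `total_batch_size` and `orig_sizes` in hand, `concat` only re-stitches tensors whose narrowed pieces add back up to the original batch; anything with a different first dimension falls through untouched. A minimal sketch of the stitching itself in plain torch (not the nnsight helpers; names are illustrative):

# Plain-torch illustration of the pre / value / post stitching done by concat().
import torch

batch_start, batch_size = 1, 2
total_batch_size = 4

acts = torch.arange(4.0).unsqueeze(1)                 # (4, 1); first dim matches the total batch
pre = acts.narrow(0, 0, batch_start)                  # rows before this invoke's batch
post_start = batch_start + batch_size
post = acts.narrow(0, post_start, acts.shape[0] - post_start)  # rows after it
new_value = torch.zeros(batch_size, 1)                # replacement activations for this invoke

stitched = torch.cat([pre, new_value, post])          # (4, 1): rows 1-2 swapped out
assert stitched.shape[0] == total_batch_size          # pieces cover the original batch, so concat proceeds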


def intervene(activations: Any, module_path: str, graph: Graph, key: str):
@@ -242,9 +271,27 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):

# We set its result to the activations, indexed by only the relevant batch idxs.

# We find the max of all first-dim sizes and assume that is the total batch size.
# We then use this to NOT narrow tensors that do not have this size as their first dim.
# TODO: maybe this isn't the right way to handle this. Maybe just check whether multiple invokes happened and, if not, don't narrow.
total_batch_size = None

def narrow(acts: torch.Tensor):
nonlocal total_batch_size

_batch_size = acts.shape[0]

if total_batch_size is None or _batch_size > total_batch_size:
total_batch_size = _batch_size

if total_batch_size == _batch_size:
return acts.narrow(0, batch_start, batch_size)

return acts

value = util.apply(
activations,
lambda x: x.narrow(0, batch_start, batch_size),
lambda x: narrow(x),
torch.Tensor,
)

@@ -254,7 +301,9 @@ def intervene(activations: Any, module_path: str, graph: Graph, key: str):
# This would mean we want to replace activations for this batch with some other ones.
value = graph.get_swap(value)

activations = concat(activations, value, batch_start, batch_size)
activations = concat(
activations, value, batch_start, batch_size, total_batch_size
)

return activations

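On the intervene side, the new rule is to slice out an invoke's rows only from tensors whose first dimension equals the assumed total batch size. A simplified sketch (the commit tracks the max first dimension on the fly while `util.apply` traverses the activations; here it is computed up front):

# Simplified illustration of the narrowing rule added to intervene().
import torch

batch_start, batch_size = 2, 2
activations = [torch.randn(4, 8), torch.randn(1, 8)]      # second tensor is not batch-shaped

total_batch_size = max(t.shape[0] for t in activations)   # assumed total batch size (= 4)

narrowed = [
    t.narrow(0, batch_start, batch_size) if t.shape[0] == total_batch_size else t
    for t in activations
]
# narrowed[0].shape == (2, 8); narrowed[1] passes through unchanged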
160 changes: 160 additions & 0 deletions src/nnsight/models/UnifiedTransformer.py
@@ -0,0 +1,160 @@
from __future__ import annotations

from typing import Any, Dict, List, Union

import torch
from transformers import BatchEncoding, PreTrainedTokenizer
from transformer_lens import HookedTransformer, HookedTransformerConfig

from .LanguageModel import LanguageModel


class UnifiedTransformer(LanguageModel):
"""UnifiedTransformer is an nnsight wrapper around TransformerLens's HookedTransformer.
Inputs can be in the form of:
Prompt: (str)
Prompts: (List[str])
Batched prompts: (List[List[str]])
Tokenized prompt: (Union[List[int], torch.Tensor])
Tokenized prompts: (Union[List[List[int]], torch.Tensor])
Direct input: (Dict[str,Any])
TransformerLens processing arguments can be passed as kwargs to the constructor.
Pass `processing=False` to call `from_pretrained_no_processing` instead of `from_pretrained`.
Calls to generate pass arguments downstream to :func:`HookedTransformer.generate`.
Attributes:
config (HookedTransformerConfig): HookedTransformer config file.
tokenizer (PreTrainedTokenizer): Tokenizer for LMs.
meta_model (HookedTransformer): Meta version of the underlying model.
local_model (HookedTransformer): Local version of the underlying HookedTransformer.
"""

def __init__(
self,
model: str,
device: str,
*args,
processing: bool = True,
**kwargs
) -> None:
if processing:
hooked_model = HookedTransformer.from_pretrained(model, *args, **kwargs)
else:
hooked_model = HookedTransformer.from_pretrained_no_processing(model, *args, **kwargs)

self.tokenizer = hooked_model.tokenizer
self.meta_model: HookedTransformer = None
self.local_model: HookedTransformer = None

super().__init__(hooked_model, tokenizer=self.tokenizer, *args, **kwargs)

self.config: HookedTransformerConfig = self.local_model.cfg
self.local_model.device = device

def update_meta(self):
super().__init__(self.local_model, tokenizer=self.tokenizer)
self.config: HookedTransformerConfig = self.local_model.cfg

def _tokenize(
self,
inputs: Union[
str,
List[str],
List[List[str]],
List[int],
List[List[int]],
torch.Tensor,
Dict[str, Any],
],
**kwargs,
):
if isinstance(inputs, BatchEncoding):
return inputs

if isinstance(inputs, str) or (
isinstance(inputs, list) and isinstance(inputs[0], int)
):
inputs = [inputs]

if isinstance(inputs, torch.Tensor) and inputs.ndim == 1:
inputs = inputs.unsqueeze(0)

if not isinstance(inputs[0], str):
inputs = [{"input_ids": ids} for ids in inputs]
return self.tokenizer.pad(inputs, return_tensors="pt", **kwargs)

return self.tokenizer(inputs, return_tensors="pt", padding=True, **kwargs)

def _prepare_inputs(
self,
inputs: Union[
str,
List[str],
List[List[str]],
List[int],
List[List[int]],
torch.Tensor,
Dict[str, Any],
BatchEncoding,
],
**kwargs,
) -> BatchEncoding:
if isinstance(inputs, dict):

new_inputs = dict()

tokenized_inputs = self._tokenize(inputs["input"], **kwargs)

new_inputs['input'] = tokenized_inputs['input_ids']

if "attention_mask" in inputs:
for ai, attn_mask in enumerate(inputs["attention_mask"]):
tokenized_inputs["attention_mask"][ai, -len(attn_mask) :] = attn_mask

new_inputs["attention_mask"] = tokenized_inputs["attention_mask"]

return BatchEncoding(new_inputs)

inputs = self._tokenize(inputs, **kwargs)

if "input_ids" in inputs:
inputs["input"] = inputs.pop("input_ids")

return inputs

def _batch_inputs(
self, prepared_inputs: BatchEncoding, batched_inputs: Dict
) -> torch.Tensor:
if batched_inputs is None:
batched_inputs = {"input": []}

if "attention_mask" in prepared_inputs:
batched_inputs["attention_mask"] = []

batched_inputs["input"].extend(prepared_inputs["input"])

if "attention_mask" in prepared_inputs:
batched_inputs["attention_mask"].extend(prepared_inputs["attention_mask"])

return batched_inputs, len(prepared_inputs["input"])

def _example_input(self) -> Dict[str, torch.Tensor]:
return BatchEncoding(
{"input": torch.tensor([[0]])}
)

def _generation(
self, prepared_inputs, *args, max_new_tokens: int = 1, **kwargs
) -> Any:

# HookedTransformer uses attention_mask in forward but not in generate.
if "attention_mask" in prepared_inputs:
prepared_inputs.pop("attention_mask")

return super()._generation(
prepared_inputs, *args, max_new_tokens=max_new_tokens, **kwargs
)
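A hypothetical usage sketch of the new wrapper, assuming the nnsight generate/invoke API of this release and TransformerLens module naming (`blocks[...]`); the model name and layer index are illustrative:

# Sketch only: assumes nnsight's generate/invoke API and TransformerLens "blocks" naming.
from nnsight.models.UnifiedTransformer import UnifiedTransformer

model = UnifiedTransformer("gpt2", device="cpu")  # pass processing=False for from_pretrained_no_processing

with model.generate(max_new_tokens=1) as generator:
    with generator.invoke("The Eiffel Tower is in the city of") as invoker:
        hidden = model.blocks[6].output.save()    # residual stream leaving block 6

print(hidden.value.shape)  # (batch, seq, d_model) for a HookedTransformer block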
