Merge pull request #52 from JadenFiotto-Kaufman/dev
Dev
JadenFiotto-Kaufman authored Jan 17, 2024
2 parents 59b2f07 + 1713411 commit 0817a66
Showing 5 changed files with 75 additions and 15 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -18,7 +18,7 @@ classifiers = [
"Operating System :: OS Independent",
]
dependencies = [
"transformers",
"transformers@git+https://github.com/huggingface/transformers",
"protobuf",
"python-socketio[client]",
"tokenizers>=0.13.0",
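The new requirement is a PEP 508 direct reference, so pip installs transformers from the repository's default branch instead of a PyPI release. This tracks a moving target; if reproducibility matters, the same syntax accepts a pinned ref (branch, tag, or commit hash) after a second @. A hypothetical pinned variant:

    "transformers@git+https://github.com/huggingface/transformers@v4.36.2",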
58 changes: 49 additions & 9 deletions src/nnsight/__init__.py
@@ -74,40 +74,79 @@ def noop(input: torch.Tensor, *args, **kwargs):

DEFAULT_PATCHER.add(Patch(torch.Tensor, noop_wrapper(torch.Tensor.cpu), "cpu"))


def onehot_wrapper(fn):
    @wraps(fn)
    def onehot(input: torch.Tensor, num_classes=-1):
        if input.device.type == "meta":
            return torch.zeros((*input.shape, num_classes), device='meta')
            return torch.zeros((*input.shape, num_classes), device="meta")

        else:
            return fn(input, num_classes=num_classes)

    return onehot


DEFAULT_PATCHER.add(Patch(torch.nn.functional, onehot_wrapper(torch.nn.functional.one_hot), "one_hot"))
DEFAULT_PATCHER.add(
    Patch(torch.nn.functional, onehot_wrapper(torch.nn.functional.one_hot), "one_hot")
)


DEFAULT_PATCHER.add(Patch(torch.Tensor, noop_wrapper(torch.Tensor.tolist), "tolist"))


def meta_nonzero(input: torch.Tensor, *args, as_tuple=False, **kwargs):
    output = torch.zeros((input.numel(), input.ndim), device="meta", dtype=torch.long)

    if as_tuple:
        return tuple([output[:, i] for i in range(input.ndim)])

    return output


def where_wrapper(fn):
def meta_nonzero_wrapper(fn):
    @wraps(fn)
    def inner(input: torch.Tensor, *args, **kwargs):
        if input.device.type == "meta":
            return meta_nonzero(input, *args, **kwargs)

        else:
            return fn(input, *args, **kwargs)

    return inner


DEFAULT_PATCHER.add(
    Patch(torch.Tensor, meta_nonzero_wrapper(torch.Tensor.nonzero), "nonzero")
)


def meta_where_wrapper(fn):
    @wraps(fn)
    def where(input: torch.Tensor, *args, **kwargs):
        if input.device.type == "meta":
            return input.to(torch.int)
            if len(args) > 0:
                dtype = args[0].dtype if isinstance(args[0], torch.Tensor) else type(args[0])
                return torch.zeros_like(input, dtype=input.dtype, device='meta')
            return meta_nonzero(input, as_tuple=True)

        else:
            return fn(input, *args, **kwargs)

    return where

DEFAULT_PATCHER.add(Patch(torch, where_wrapper(torch.where), "where"))

DEFAULT_PATCHER.add(Patch(torch.Tensor, noop_wrapper(torch.Tensor.tolist), "tolist"))
DEFAULT_PATCHER.add(Patch(torch, meta_where_wrapper(torch.where), "where"))


DEFAULT_PATCHER.__enter__()

from torch._meta_registrations import (_meta_lib_dont_use_me_use_register_meta,
                                       aten, global_decomposition_table,
                                       register_meta)
from torch._meta_registrations import (
    _meta_lib_dont_use_me_use_register_meta,
    aten,
    global_decomposition_table,
    register_meta,
)


# Function which "activates" the most recent meta registered function.
@@ -124,4 +163,5 @@ def activate_recent_meta():
def local_scalar_dense_meta(A):
    return 0


activate_recent_meta()
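All of the patches above share one pattern: intercept a torch function and, when the input lives on the meta device (shape and dtype only, no data), return a correctly shaped placeholder instead of running the real op. A standalone sketch of that pattern, using plain monkey-patching in place of nnsight's Patch/Patcher machinery (meta_guard and fake_nonzero are illustrative names, not nnsight API):

import torch
from functools import wraps


def meta_guard(fn, fallback):
    # Dispatch to `fallback` for meta tensors; run the real op otherwise.
    @wraps(fn)
    def inner(input: torch.Tensor, *args, **kwargs):
        if input.device.type == "meta":
            return fallback(input, *args, **kwargs)
        return fn(input, *args, **kwargs)

    return inner


def fake_nonzero(input: torch.Tensor, *args, as_tuple=False, **kwargs):
    # Mirrors meta_nonzero above: one row per element, one column per dimension,
    # i.e. the worst-case shape where every element is nonzero.
    out = torch.zeros((input.numel(), input.ndim), device="meta", dtype=torch.long)
    if as_tuple:
        return tuple(out[:, i] for i in range(input.ndim))
    return out


torch.Tensor.nonzero = meta_guard(torch.Tensor.nonzero, fake_nonzero)

print(torch.ones(2, 3, device="meta").nonzero().shape)  # torch.Size([6, 2])

A worst-case placeholder is all a shape-tracing pass needs; the actual values are never inspected on the meta device.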
18 changes: 15 additions & 3 deletions src/nnsight/contexts/Runner.py
@@ -84,8 +84,8 @@ def run_local(self):
        )

    def run_server(self):
        # Create the pydantic class for the request.

        # Create the pydantic object for the request.
        request = pydantics.RequestModel(
            args=self.args,
            kwargs=self.kwargs,
@@ -123,14 +123,18 @@ def blocking_response(data):

            # If the status of the response is completed, update the local nodes that the user specified to save.
            # Then disconnect and continue.

            if response.status == pydantics.ResponseModel.JobStatus.COMPLETED:

                # Create BytesIO object to store bytes received from server in.
                result_bytes = io.BytesIO()
                result_bytes.seek(0)

                # Get result from result url using job id.
                with requests.get(
                    url=f"https://{CONFIG.API.HOST}/result/{response.id}", stream=True
                ) as stream:

                    # Total size of incoming data.
                    total_size = float(stream.headers["Content-length"])

                    with tqdm(
@@ -139,22 +143,30 @@ def blocking_response(data):
                        unit_scale=True,
                        desc="Downloading result",
                    ) as progress_bar:
                        for data in stream.iter_content(chunk_size=4000000):
                        # chunk_size=None so server determines chunk size.
                        for data in stream.iter_content(chunk_size=None):
                            progress_bar.update(len(data))
                            result_bytes.write(data)

                # Move cursor to beginning of bytes.
                result_bytes.seek(0)

                # Decode bytes with pickle and then into pydantic object.
                result = pydantics.ResultModel(**pickle.load(result_bytes))

                # Close bytes
                result_bytes.close()

                # Set save data.
                for name, value in result.saves.items():
                    self.graph.nodes[name].value = value

                # Set output data.
                self.output = result.output

                # Disconnect and continue program.
                sio.disconnect()

            # Or if there was some error.
            elif response.status == pydantics.ResponseModel.JobStatus.ERROR:
                sio.disconnect()
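The notable change on the download path is chunk_size=None: requests then yields each chunk as it arrives from the server instead of buffering to a fixed 4 MB block. A self-contained sketch of the same streaming pattern (the URL is a placeholder, not the real NDIF endpoint):

import io

import requests
from tqdm import tqdm


def download(url: str) -> io.BytesIO:
    buffer = io.BytesIO()
    # Stream the body so large results are written out as they arrive.
    with requests.get(url, stream=True) as stream:
        total_size = float(stream.headers["Content-length"])
        with tqdm(
            total=total_size, unit="B", unit_scale=True, desc="Downloading result"
        ) as progress_bar:
            # chunk_size=None: iterate over chunks in whatever size they arrive.
            for data in stream.iter_content(chunk_size=None):
                progress_bar.update(len(data))
                buffer.write(data)
    buffer.seek(0)  # rewind so the caller reads from the start
    return buffer


# result_bytes = download(f"https://{host}/result/{job_id}")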
9 changes: 8 additions & 1 deletion src/nnsight/models/LanguageModel.py
@@ -113,7 +113,9 @@ def _prepare_inputs(
        if isinstance(inputs, dict):
            _inputs = self._tokenize(inputs["input_ids"])

            _inputs = self._tokenize(_inputs)
            for ai, attn_mask in enumerate(inputs['attention_mask']):

                _inputs['attention_mask'][ai, -len(attn_mask):] = attn_mask

            if "labels" in inputs:
                labels = self._tokenize(inputs["labels"])
@@ -140,10 +142,15 @@ def _batch_inputs(
if "labels" in prepared_inputs:
batched_inputs["labels"] = []

if "attention_mask" in prepared_inputs:
batched_inputs["attention_mask"] = []

batched_inputs["input_ids"].extend(prepared_inputs["input_ids"])

if "labels" in prepared_inputs:
batched_inputs["labels"].extend(prepared_inputs["labels"])
if "attention_mask" in prepared_inputs:
batched_inputs["attention_mask"].extend(prepared_inputs["attention_mask"])

return batched_inputs, len(prepared_inputs["input_ids"])

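The _prepare_inputs change right-aligns each caller-supplied attention mask against the freshly tokenized batch: the slice [ai, -len(attn_mask):] writes into the tail of row ai, which matches left-padded inputs. A toy illustration with made-up mask values:

import torch

# Tokenizer produced a left-padded batch of length 5; the caller
# supplied shorter per-row masks to overlay on top of it.
attention_mask = torch.ones(2, 5, dtype=torch.long)
provided = [torch.tensor([1, 1, 0]), torch.tensor([0, 1])]

for ai, attn_mask in enumerate(provided):
    # Overwrite only the last len(attn_mask) positions of row ai;
    # the left-padding region keeps the tokenizer's values.
    attention_mask[ai, -len(attn_mask):] = attn_mask

print(attention_mask)
# tensor([[1, 1, 1, 1, 0],
#         [1, 1, 1, 0, 1]])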
3 changes: 2 additions & 1 deletion src/nnsight/models/NNsightModel.py
@@ -382,7 +382,8 @@ def _scan(self, prepared_inputs, *args, **kwargs) -> None:
        Args:
            prepared_inputs (Any): Prepared inputs.
        """
        return self.meta_model(**prepared_inputs.copy().to("meta"))
        with accelerate.init_empty_weights(include_buffers=True):
            return self.meta_model(**prepared_inputs.copy().to("meta"))

    def _forward(self, prepared_inputs, *args, **kwargs) -> Any:
        """
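Wrapping the meta-model call in accelerate's init_empty_weights(include_buffers=True) makes any parameters or buffers created during the shape-scanning forward pass land on the meta device, so the scan allocates no real memory. The idea on a toy module:

import accelerate
import torch

with accelerate.init_empty_weights(include_buffers=True):
    layer = torch.nn.Linear(10, 10)  # weight and bias are created on the meta device

print(layer.weight.device)  # device(type='meta')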