From 17f35c70f82a8778c9965b3b0867d642f1479851 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Mon, 18 Mar 2024 12:22:52 -0400 Subject: [PATCH 1/8] Default device cpu for Mamba models --- src/nnsight/models/Mamba.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/nnsight/models/Mamba.py b/src/nnsight/models/Mamba.py index 283407a7..9d1e225e 100644 --- a/src/nnsight/models/Mamba.py +++ b/src/nnsight/models/Mamba.py @@ -23,7 +23,7 @@ class Mamba(LanguageModel): - def _load(self, repo_id: str, device='meta', **kwargs) -> PreTrainedModel: + def _load(self, repo_id: str, device='cpu', **kwargs) -> PreTrainedModel: config = MambaConfig(**load_config_hf(repo_id)) From 8ffbefe2e3ec273c6ade4ba2e994064f17caaa2c Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Mon, 18 Mar 2024 22:27:01 -0400 Subject: [PATCH 2/8] Patching of accelerate not handling 'meta; and Fake tensors in some cases. Maybe revisit --- src/nnsight/__init__.py | 303 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 303 insertions(+) diff --git a/src/nnsight/__init__.py b/src/nnsight/__init__.py index 218e7bb6..322b9404 100644 --- a/src/nnsight/__init__.py +++ b/src/nnsight/__init__.py @@ -1,5 +1,6 @@ from functools import wraps import os +from typing import Union import yaml import torch @@ -86,4 +87,306 @@ def noop(input: torch.Tensor, *args, **kwargs): DEFAULT_PATCHER.add(Patch(FakeTensor, noop_wrapper(FakeTensor.tolist), "tolist")) +import warnings + + # Hacky patch to get around the fact this init method has no handling for 'meta' tensors. +def autoamp_init( + self, + device_type: str, + dtype = None, + enabled: bool = True, + cache_enabled: Optional[bool] = None, +): + if torch._jit_internal.is_scripting(): + self._enabled = enabled + self.device = device_type + self.fast_dtype = dtype + # TODO: support get_autocast_gpu/cpu_dtype + assert dtype is not None + return + self.device = device_type + self.custom_backend_name = torch._C._get_privateuse1_backend_name() + + if self.device == "cuda": + self.fast_dtype = torch.get_autocast_gpu_dtype() + ### PATCH ### + elif self.device == "meta": + self.fast_dtype = torch.get_autocast_cpu_dtype() + ### PATCH ### + elif self.device == "cpu": + self.fast_dtype = torch.get_autocast_cpu_dtype() + elif self.device == "xpu": + self.fast_dtype = torch.xpu.get_autocast_xpu_dtype() # type: ignore[attr-defined] + elif self.device == "ipu": + self.fast_dtype = torch.get_autocast_ipu_dtype() # type: ignore[attr-defined] + elif self.device == "hpu": + self.fast_dtype = torch.hpu.get_autocast_hpu_dtype() # type: ignore[attr-defined] + elif self.device == "xla": + self.fast_dtype = torch.get_autocast_xla_dtype() # type: ignore[attr-defined] + elif self.device == self.custom_backend_name: + necessary_funcs = [ + "is_autocast_enabled", + "set_autocast_enabled", + "get_autocast_dtype", + "set_autocast_dtype", + "get_amp_supported_dtype", + ] + message = f"Tried to use AMP with the `{self.custom_backend_name}` backend, but the backend has not " + message += "registered a module or the module miss some necessary funcs. The backend should register " + message += "a module by `torch._register_device_module`, and the module must have these funcs: \n" + message += "`is_autocast_enabled() -> bool`, `set_autocast_enabled(bool) -> None`, " + message += "`get_autocast_dtype() -> torch.dtype`, `set_autocast_dtype(torch.dtype) " + message += ( + "-> None` and `get_amp_supported_dtype() -> List[torch.dtype]`. 
\n" + ) + + assert hasattr(torch, self.custom_backend_name), message + self.custom_device_mod = getattr(torch, self.custom_backend_name) + for func in necessary_funcs: + assert hasattr(self.custom_device_mod, func), ( + message + f"But the func `{func}` is missing. \n" + ) + + self.fast_dtype = self.custom_device_mod.get_autocast_dtype() + else: + raise RuntimeError( + f"User specified an unsupported autocast device_type '{self.device}'" + ) + self._cache_enabled = torch.is_autocast_cache_enabled() + if ( + enabled + and torch.cuda.amp.common.amp_definitely_not_available() + and self.device == "cuda" + ): + warnings.warn( + "User provided device_type of 'cuda', but CUDA is not available. Disabling" + ) + enabled = False + if dtype is not None: + self.fast_dtype = dtype + if cache_enabled is not None: + self._cache_enabled = cache_enabled + + if self.device == "cpu": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype and enabled: + error_message = "In CPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "CPU Autocast only supports dtype of " + error_message += ( + ", ".join(str(dtype) for dtype in supported_dtype) + " currently." + ) + warnings.warn(error_message) + enabled = False + elif self.device == "xpu": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = "In XPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "XPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." + warnings.warn(error_message) + enabled = False + elif self.device == "ipu": + supported_dtypes = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtypes: + error_message = "In IPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "IPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." + warnings.warn(error_message) + enabled = False + elif self.device == "hpu": + supported_dtype = [torch.bfloat16, torch.float16] + if self.fast_dtype not in supported_dtype: + error_message = "In HPU autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += "HPU Autocast only supports dtypes of torch.bfloat16 and torch.float16 currently." + warnings.warn(error_message) + enabled = False + elif self.device == self.custom_backend_name: + supported_dtype = self.custom_device_mod.get_amp_supported_dtype() + if self.fast_dtype not in supported_dtype: + error_message = f"In {self.custom_backend_name} autocast, but the target dtype is not supported. " + error_message += f"Disabling autocast.\n {self.custom_backend_name} Autocast only supports dtypes of " + error_message += ( + ", ".join(str(dtype) for dtype in supported_dtype) + " currently." + ) + warnings.warn(error_message) + enabled = False + elif self.device == "cuda": + if ( + enabled + and self.fast_dtype == torch.bfloat16 + and not torch.cuda.is_bf16_supported() + ): + raise RuntimeError( + "Current CUDA Device does not support bfloat16. Please switch dtype to float16." + ) + elif self.device == "xla": + supported_dtype = [torch.float16, torch.bfloat16] + if self.fast_dtype not in supported_dtype: + error_message = "In XLA autocast, but the target dtype is not supported. Disabling autocast.\n" + error_message += ( + "XLA Autocast only supports dtype of torch.bfloat16 currently." 
+ ) + warnings.warn(error_message) + enabled = False + self._enabled = enabled + +from torch.amp.autocast_mode import autocast + +DEFAULT_PATCHER.add(Patch(autocast, autoamp_init, "__init__")) + +from accelerate.utils.modeling import is_npu_available, check_device_same + +# Hacky patch to get around this function trying to set the parameter of a non meta tensor to meta. +# Also handles FakeTensors. +def set_module_tensor_to_device( + module: torch.nn.Module, + tensor_name: str, + device: Union[int, str, torch.device], + value: Optional[torch.Tensor] = None, + dtype: Optional[Union[str, torch.dtype]] = None, + fp16_statistics: Optional[torch.HalfTensor] = None, +): + """ + A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing + `param.to(device)` creates a new tensor not linked to the parameter, which is why we need this function). + + Args: + module (`torch.nn.Module`): + The module in which the tensor we want to move lives. + tensor_name (`str`): + The full name of the parameter/buffer. + device (`int`, `str` or `torch.device`): + The device on which to set the tensor. + value (`torch.Tensor`, *optional*): + The value of the tensor (useful when going from the meta device to any other device). + dtype (`torch.dtype`, *optional*): + If passed along the value of the parameter will be cast to this `dtype`. Otherwise, `value` will be cast to + the dtype of the existing parameter in the model. + fp16_statistics (`torch.HalfTensor`, *optional*): + The list of fp16 statistics to set on the module, used for 8 bit model serialization. + """ + ### PATCH ### + if isinstance(device, str) and device == 'meta': + return + ### PATCH ### + # Recurse if needed + if "." in tensor_name: + splits = tensor_name.split(".") + for split in splits[:-1]: + new_module = getattr(module, split) + if new_module is None: + raise ValueError(f"{module} has no attribute {split}.") + module = new_module + tensor_name = splits[-1] + + if tensor_name not in module._parameters and tensor_name not in module._buffers: + raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.") + is_buffer = tensor_name in module._buffers + old_value = getattr(module, tensor_name) + + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: + raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") + + if value is not None: + if old_value.shape != value.shape: + raise ValueError( + f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this look incorrect.' 
+ ) + + if dtype is None: + # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model + value = value.to(old_value.dtype) + elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + value = value.to(dtype) + + param = module._parameters[tensor_name] if tensor_name in module._parameters else None + param_cls = type(param) + + device_quantization = None + with torch.no_grad(): + # leave it on cpu first before moving them to cuda + # # fix the case where the device is meta, we don't want to put it on cpu because there is no data =0 + if ( + param is not None + and param.device.type != "cuda" + and torch.device(device).type == "cuda" + and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"] + ): + device_quantization = device + device = "cpu" + # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). + if is_npu_available() and isinstance(device, int): + device = f"npu:{device}" + if value is None: + new_value = old_value.to(device) + if dtype is not None and device in ["meta", torch.device("meta")]: + if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")): + new_value = new_value.to(dtype) + + if not is_buffer: + module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad) + elif isinstance(value, torch.Tensor): + new_value = value.to(device) + else: + new_value = torch.tensor(value, device=device) + if device_quantization is not None: + device = device_quantization + if is_buffer: + module._buffers[tensor_name] = new_value + elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device): + param_cls = type(module._parameters[tensor_name]) + kwargs = module._parameters[tensor_name].__dict__ + if param_cls.__name__ in ["Int8Params", "FP4Params"]: + if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32: + # downcast to fp16 if any - needed for 8bit serialization + new_value = new_value.to(torch.float16) + # quantize module that are going to stay on the cpu so that we offload quantized weights + if device == "cpu" and param_cls.__name__ == "Int8Params": + new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu") + new_value.CB = new_value.CB.to("cpu") + new_value.SCB = new_value.SCB.to("cpu") + else: + new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device) + ### PATCH ### + elif isinstance(old_value, FakeTensor): + new_value = new_value + ### PATCH ### + else: + new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device) + + module._parameters[tensor_name] = new_value + if fp16_statistics is not None: + setattr(module._parameters[tensor_name], "SCB", fp16_statistics.to(device)) + del fp16_statistics + # as we put the weight to meta, it doesn't have SCB attr anymore. 
make sure that it is not a meta weight + if ( + module.__class__.__name__ == "Linear8bitLt" + and getattr(module.weight, "SCB", None) is None + and str(module.weight.device) != "meta" + ): + # quantize only if necessary + device_index = torch.device(device).index if torch.device(device).type == "cuda" else None + if not getattr(module.weight, "SCB", None) and device_index is not None: + if module.bias is not None and module.bias.device.type != "meta": + # if a bias exists, we need to wait until the bias is set on the correct device + module = module.cuda(device_index) + elif module.bias is None: + # if no bias exists, we can quantize right away + module = module.cuda(device_index) + elif module.__class__.__name__ == "Linear4bit" and getattr(module.weight, "quant_state", None) is None: + # quantize only if necessary + device_index = torch.device(device).index if torch.device(device).type == "cuda" else None + if not getattr(module.weight, "quant_state", None) and device_index is not None: + module.weight = module.weight.cuda(device_index) + # clean pre and post foward hook + if is_npu_available(): + torch.npu.empty_cache() + else: + torch.cuda.empty_cache() + +from accelerate import hooks + + + +DEFAULT_PATCHER.add(Patch(hooks, set_module_tensor_to_device, "set_module_tensor_to_device")) + + DEFAULT_PATCHER.__enter__() From 2e569d1e939e9d76cc21d15553134ed46cddab13 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Mon, 18 Mar 2024 22:32:05 -0400 Subject: [PATCH 3/8] More hacky patching for accelerate. Revisit --- src/nnsight/__init__.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/nnsight/__init__.py b/src/nnsight/__init__.py index 322b9404..850ac0ab 100644 --- a/src/nnsight/__init__.py +++ b/src/nnsight/__init__.py @@ -348,10 +348,12 @@ def set_module_tensor_to_device( ### PATCH ### elif isinstance(old_value, FakeTensor): new_value = new_value + elif isinstance(new_value, FakeTensor): + new_value = new_value ### PATCH ### else: new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device) - + module._parameters[tensor_name] = new_value if fp16_statistics is not None: setattr(module._parameters[tensor_name], "SCB", fp16_statistics.to(device)) From 53278c4546c634f3fe386b7e9c392c82fda23f08 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Mon, 18 Mar 2024 23:01:39 -0400 Subject: [PATCH 4/8] Update last patch for newest accelerate version --- src/nnsight/__init__.py | 66 ++++++++++++++++++++++++++++++++--------- 1 file changed, 52 insertions(+), 14 deletions(-) diff --git a/src/nnsight/__init__.py b/src/nnsight/__init__.py index 850ac0ab..bd72f7e3 100644 --- a/src/nnsight/__init__.py +++ b/src/nnsight/__init__.py @@ -1,6 +1,6 @@ from functools import wraps import os -from typing import Union +from typing import Dict, Union import yaml import torch @@ -232,7 +232,7 @@ def autoamp_init( DEFAULT_PATCHER.add(Patch(autocast, autoamp_init, "__init__")) -from accelerate.utils.modeling import is_npu_available, check_device_same +from accelerate.utils.modeling import is_npu_available, check_device_same, is_xpu_available # Hacky patch to get around this function trying to set the parameter of a non meta tensor to meta. # Also handles FakeTensors. 
@@ -243,6 +243,7 @@ def set_module_tensor_to_device( value: Optional[torch.Tensor] = None, dtype: Optional[Union[str, torch.dtype]] = None, fp16_statistics: Optional[torch.HalfTensor] = None, + tied_params_map: Optional[Dict[int, Dict[torch.device, torch.Tensor]]] = None, ): """ A helper function to set a given tensor (parameter of buffer) of a module on a specific device (note that doing @@ -262,11 +263,11 @@ def set_module_tensor_to_device( the dtype of the existing parameter in the model. fp16_statistics (`torch.HalfTensor`, *optional*): The list of fp16 statistics to set on the module, used for 8 bit model serialization. + tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`): + A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given + execution device, this parameter is useful to reuse the first available pointer of a shared weight on the + device for all others, instead of duplicating memory. """ - ### PATCH ### - if isinstance(device, str) and device == 'meta': - return - ### PATCH ### # Recurse if needed if "." in tensor_name: splits = tensor_name.split(".") @@ -282,6 +283,24 @@ def set_module_tensor_to_device( is_buffer = tensor_name in module._buffers old_value = getattr(module, tensor_name) + # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight + # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer. + if ( + value is not None + and tied_params_map is not None + and value.data_ptr() in tied_params_map + and device in tied_params_map[value.data_ptr()] + ): + module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device] + return + elif ( + tied_params_map is not None + and old_value.data_ptr() in tied_params_map + and device in tied_params_map[old_value.data_ptr()] + ): + module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device] + return + if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None: raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.") @@ -315,6 +334,8 @@ def set_module_tensor_to_device( # `torch.Tensor.to()` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)). 
if is_npu_available() and isinstance(device, int): device = f"npu:{device}" + if is_xpu_available() and isinstance(device, int): + device = f"xpu:{device}" if value is None: new_value = old_value.to(device) if dtype is not None and device in ["meta", torch.device("meta")]: @@ -345,18 +366,16 @@ def set_module_tensor_to_device( new_value.SCB = new_value.SCB.to("cpu") else: new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(device) - ### PATCH ### - elif isinstance(old_value, FakeTensor): - new_value = new_value - elif isinstance(new_value, FakeTensor): - new_value = new_value - ### PATCH ### + elif param_cls.__name__ in ["QTensor", "QBitsTensor"]: + new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device) + elif isinstance(new_value, FakeTensor) or isinstance(old_value, FakeTensor): + new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(device) else: new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(device) - + module._parameters[tensor_name] = new_value if fp16_statistics is not None: - setattr(module._parameters[tensor_name], "SCB", fp16_statistics.to(device)) + module._parameters[tensor_name].SCB = fp16_statistics.to(device) del fp16_statistics # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight if ( @@ -381,8 +400,27 @@ def set_module_tensor_to_device( # clean pre and post foward hook if is_npu_available(): torch.npu.empty_cache() + elif is_xpu_available(): + torch.xpu.empty_cache() else: torch.cuda.empty_cache() + + # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in + # order to avoid duplicating memory, see above. + if ( + tied_params_map is not None + and old_value.data_ptr() in tied_params_map + and device not in tied_params_map[old_value.data_ptr()] + ): + tied_params_map[old_value.data_ptr()][device] = new_value + elif ( + value is not None + and tied_params_map is not None + and value.data_ptr() in tied_params_map + and device not in tied_params_map[value.data_ptr()] + ): + tied_params_map[value.data_ptr()][device] = new_value + from accelerate import hooks From cd61f27840e5b3d07df7dbee1a386a06a4591f8a Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Tue, 19 Mar 2024 16:08:56 -0400 Subject: [PATCH 5/8] Temp --- src/nnsight/envoy.py | 27 ++++++++++----------------- src/nnsight/models/NNsightModel.py | 12 ++++++++++++ 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/src/nnsight/envoy.py b/src/nnsight/envoy.py index 781694db..3935e3e6 100644 --- a/src/nnsight/envoy.py +++ b/src/nnsight/envoy.py @@ -47,29 +47,22 @@ def __init__(self, module: torch.nn.Module, module_path: str = ""): self._hook, with_kwargs=True ) - if isinstance(module, torch.nn.ModuleList): - for i, module in enumerate(self._module): - envoy = Envoy(module, module_path=f"{self._module_path}.{i}") + for name, module in self._module.named_children(): - self._sub_envoys.append(envoy) + envoy = Envoy(module, module_path=f"{self._module_path}.{name}") - else: - for name, module in self._module.named_children(): - - envoy = Envoy(module, module_path=f"{self._module_path}.{name}") - - self._sub_envoys.append(envoy) + self._sub_envoys.append(envoy) - # If the module already has a sub-module named 'input' or 'output', - # mount the proxy access to 'nns_input' or 'nns_output instead. 
- if hasattr(Envoy, name): + # If the module already has a sub-module named 'input' or 'output', + # mount the proxy access to 'nns_input' or 'nns_output instead. + if hasattr(Envoy, name): - self._handle_overloaded_mount(envoy, name) + self._handle_overloaded_mount(envoy, name) - else: - - setattr(self, name, envoy) + else: + + setattr(self, name, envoy) def _handle_overloaded_mount(self, envoy: Envoy, mount_point: str): diff --git a/src/nnsight/models/NNsightModel.py b/src/nnsight/models/NNsightModel.py index 80899990..0bfe5535 100644 --- a/src/nnsight/models/NNsightModel.py +++ b/src/nnsight/models/NNsightModel.py @@ -6,6 +6,7 @@ import accelerate import torch from transformers import AutoConfig, AutoModel +from typing_extensions import Self from .. import util from ..contexts.Runner import Runner @@ -276,6 +277,17 @@ def dispatch_model(self, *args, **kwargs) -> None: logger.info(f"Dispatched `{self._model_key}`") + def to(self, *args, **kwargs) -> Self: + """Override torch.nn.Module.to so this returns the NNSight model, not the underlying module when doing: model = model.to(...) + + Returns: + Envoy: Envoy. + """ + + self._model = self._model.to(*args, **kwargs) + + return self + def __repr__(self) -> str: """Wrapper of ._model's representation as the NNsight model's representation. From 1f0847ce9db1175288d87b2e4111b9813bad171d Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Fri, 22 Mar 2024 13:39:38 -0400 Subject: [PATCH 6/8] Catch referencing _tracer in Envoy as it may be deallocated --- src/nnsight/envoy.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/nnsight/envoy.py b/src/nnsight/envoy.py index 3935e3e6..46677b58 100644 --- a/src/nnsight/envoy.py +++ b/src/nnsight/envoy.py @@ -228,6 +228,9 @@ def __repr__(self) -> str: extra_lines = extra_repr.split("\n") child_lines = [] for attribute_name, attribute in self.__dict__.items(): + + if attribute_name == '_tracer': + continue if isinstance(attribute, Envoy): From cca83ee81f3e1fa70830467f91f446d3b336e125 Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Sat, 23 Mar 2024 12:08:00 -0400 Subject: [PATCH 7/8] Handle the LanguageModel case where a pre loaded model is passed in. Need to add the generator module to it. --- src/nnsight/models/LanguageModel.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/nnsight/models/LanguageModel.py b/src/nnsight/models/LanguageModel.py index 4c15aabe..354aa929 100644 --- a/src/nnsight/models/LanguageModel.py +++ b/src/nnsight/models/LanguageModel.py @@ -138,6 +138,10 @@ def __init__( ) super().__init__(*args, **kwargs) + + if not hasattr(self._model, 'generator'): + + setattr(self._model, 'generator', WrapperModule()) def _load(self, repo_id: str, **kwargs) -> PreTrainedModel: @@ -154,8 +158,6 @@ def _load(self, repo_id: str, **kwargs) -> PreTrainedModel: model = self.automodel.from_config(config, trust_remote_code=True) - setattr(model, 'generator', WrapperModule()) - return model model = self.automodel.from_pretrained(repo_id, config=config, **kwargs) From bc405c30cf2ea47263c23f80d7b944de0078e95a Mon Sep 17 00:00:00 2001 From: "jadenfk@outlook.com" Date: Sat, 23 Mar 2024 12:15:48 -0400 Subject: [PATCH 8/8] Handle case in Envoy when applying a module that has no parameters. 
--- src/nnsight/envoy.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/src/nnsight/envoy.py b/src/nnsight/envoy.py index 46677b58..9995eb6d 100644 --- a/src/nnsight/envoy.py +++ b/src/nnsight/envoy.py @@ -47,7 +47,6 @@ def __init__(self, module: torch.nn.Module, module_path: str = ""): self._hook, with_kwargs=True ) - for name, module in self._module.named_children(): envoy = Envoy(module, module_path=f"{self._module_path}.{name}") @@ -61,7 +60,7 @@ def __init__(self, module: torch.nn.Module, module_path: str = ""): self._handle_overloaded_mount(envoy, name) else: - + setattr(self, name, envoy) def _handle_overloaded_mount(self, envoy: Envoy, mount_point: str): @@ -228,8 +227,8 @@ def __repr__(self) -> str: extra_lines = extra_repr.split("\n") child_lines = [] for attribute_name, attribute in self.__dict__.items(): - - if attribute_name == '_tracer': + + if attribute_name == "_tracer": continue if isinstance(attribute, Envoy): @@ -303,7 +302,15 @@ def __call__(self, *args: List[Any], **kwargs: Dict[str, Any]) -> InterventionPr module_proxy = getattr(self._tracer._graph.module_proxy, self._module_path) - torch.set_default_device(next(self._module.parameters()).device) + try: + + device = next(self._module.parameters()).device + + except: + + device = torch.device("cpu") + + torch.set_default_device(device) proxy = module_proxy.forward(*args, **kwargs)
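
Patches 2-4 above all go through the same idiom that already exists at the top of src/nnsight/__init__.py: DEFAULT_PATCHER.add(Patch(parent, replacement, attribute_name)), followed by DEFAULT_PATCHER.__enter__(), which is how autocast.__init__ and accelerate's set_module_tensor_to_device get swapped for the patched versions. The sketch below is a simplified, self-contained stand-in for that idiom, not nnsight's actual Patch/Patcher classes; SimplePatch, SimplePatcher, and fake_rand are illustrative names only.

    import torch

    class SimplePatch:
        """Stand-in for the Patch objects used above: replace
        getattr(parent, name) with `replacement`, remembering the original."""

        def __init__(self, parent, replacement, name):
            self.parent = parent
            self.replacement = replacement
            self.name = name
            self.original = getattr(parent, name)

        def apply(self):
            setattr(self.parent, self.name, self.replacement)

        def restore(self):
            setattr(self.parent, self.name, self.original)

    class SimplePatcher:
        """Collects patches and applies/undoes them as a context manager."""

        def __init__(self):
            self.patches = []

        def add(self, patch):
            self.patches.append(patch)

        def __enter__(self):
            for patch in self.patches:
                patch.apply()
            return self

        def __exit__(self, *exc):
            for patch in self.patches:
                patch.restore()

    def fake_rand(*shape):
        # Toy replacement so the effect of patching is easy to see.
        return torch.ones(*shape)

    patcher = SimplePatcher()
    patcher.add(SimplePatch(torch, fake_rand, "rand"))

    with patcher:
        assert torch.rand is fake_rand   # patched inside the context
        print(torch.rand(2, 2))          # all ones
    assert torch.rand is not fake_rand   # restored on exit

In the sketch, capturing the original attribute at construction time is what lets __exit__ put things back; the real DEFAULT_PATCHER is simply entered once at import time, so its replacements stay active.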
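
The last patch guards next(self._module.parameters()).device against modules that own no parameters at all, presumably including the bare WrapperModule mounted as 'generator' in patch 7; next() on an empty parameter iterator raises StopIteration. Below is a minimal standalone sketch of that fallback, assuming nothing beyond plain PyTorch (module_device is a hypothetical helper name, not part of nnsight).

    import torch

    def module_device(module: torch.nn.Module) -> torch.device:
        """Device of the module's first parameter, defaulting to CPU when
        the module has no parameters (the case patch 8 works around)."""
        try:
            return next(module.parameters()).device
        except StopIteration:
            return torch.device("cpu")

    print(module_device(torch.nn.Linear(4, 4)))  # device of the Linear's weight
    print(module_device(torch.nn.Identity()))    # cpu, via the fallback

Catching StopIteration explicitly keeps unrelated failures visible; the patch itself uses a bare except, so the sketch is a slightly narrower variation on the same idea.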