From 0acaf0af2977a12778a2eb97cbb9abf87e95afa8 Mon Sep 17 00:00:00 2001
From: Francois Ledoyen
Date: Mon, 17 Feb 2025 17:50:46 +0100
Subject: [PATCH] fix hf model generate: add forward context args in forward signature

---
 src/adapters/context.py        | 26 ++++++++++++++++++++++++++
 src/adapters/wrappers/model.py |  3 ++-
 2 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/src/adapters/context.py b/src/adapters/context.py
index 4dc0f3d1e..31b0fd5bf 100644
--- a/src/adapters/context.py
+++ b/src/adapters/context.py
@@ -1,4 +1,5 @@
 import functools
+import inspect
 import threading
 from typing import ContextManager
 
@@ -121,6 +122,31 @@ def _call_forward(self, model, f, *args, **kwargs):
 
         return results
 
+    @classmethod
+    def add_context_args_in_signature(cls, f):
+        old_signature = inspect.signature(f)
+        params = list(old_signature.parameters.values())
+        # Check whether a VAR_POSITIONAL or VAR_KEYWORD parameter is present:
+        # if so, insert the context args before it, otherwise append them at the end.
+        param_types = [param.kind for param in params]
+        i = min(
+            [
+                (param_types.index(param_type) if param_type in param_types else float("inf"))
+                for param_type in (
+                    inspect.Parameter.VAR_POSITIONAL,
+                    inspect.Parameter.VAR_KEYWORD,
+                )
+            ]
+            + [len(params)]
+        )
+        for name in cls.context_args:
+            new_param = inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None)
+            if new_param not in params:
+                params.insert(i, new_param)
+        # we can now build the signature for the wrapper function
+        new_signature = old_signature.replace(parameters=params)
+        return new_signature
+
     @classmethod
     def wrap_base(cls, f):
diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py
index 1f54e29ca..4d971cca7 100644
--- a/src/adapters/wrappers/model.py
+++ b/src/adapters/wrappers/model.py
@@ -5,6 +5,7 @@
 
 from torch import nn
 
+from adapters.context import ForwardContext
 from transformers import PreTrainedModel
 from transformers.models.auto.auto_factory import getattribute_from_module
 from transformers.models.auto.configuration_auto import model_type_to_module_name
@@ -80,7 +81,7 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig]
         base_model = getattr(model, model.base_model_prefix)
         if isinstance(base_model, ModelAdaptersMixin):
             # HACK to preserve original forward method signature (e.g. for Trainer label names)
-            temp_signature = inspect.signature(model.forward.__func__)
+            temp_signature = ForwardContext.add_context_args_in_signature(model.forward.__func__)
             # Create new wrapper model class
             model_class_name = model.__class__.__name__
             model_class = type(
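
The new classmethod rewrites the wrapped forward's signature with inspect so that code relying on
signature inspection (e.g. generate() kwarg handling and Trainer label-name detection, as the HACK
comment above notes) also sees the forward-context arguments instead of dropping them. Below is a
minimal, standalone sketch of the same inspect-based technique; the forward() stub, the helper name
add_args_to_signature, and the name "example_context_arg" are illustrative assumptions, not part of
the adapters codebase:

    import inspect


    def forward(self, input_ids=None, attention_mask=None, **kwargs):
        # Stand-in for a wrapped model.forward (illustrative only).
        return {"input_ids": input_ids, "attention_mask": attention_mask, **kwargs}


    def add_args_to_signature(f, extra_args):
        # Return a copy of f's signature with extra_args inserted as
        # POSITIONAL_OR_KEYWORD parameters before any *args/**kwargs.
        old_signature = inspect.signature(f)
        params = list(old_signature.parameters.values())
        kinds = [p.kind for p in params]
        insert_at = min(
            [kinds.index(k) for k in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD) if k in kinds]
            + [len(params)]
        )
        for name in extra_args:
            if name not in old_signature.parameters:
                new_param = inspect.Parameter(name, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None)
                params.insert(insert_at, new_param)
        return old_signature.replace(parameters=params)


    print(add_args_to_signature(forward, ["example_context_arg"]))
    # (self, input_ids=None, attention_mask=None, example_context_arg=None, **kwargs)

Inserting before *args/**kwargs keeps the rewritten signature valid, since no parameter may follow a
VAR_KEYWORD parameter when the new Signature object is built.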