diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py index 1f54e29ca..1038e3ec5 100644 --- a/src/adapters/wrappers/model.py +++ b/src/adapters/wrappers/model.py @@ -18,6 +18,7 @@ ) from ..models import MODEL_MIXIN_MAPPING from .configuration import init_adapters_config +from ..context import ForwardContext SPECIAL_MODEL_TYPE_TO_MODULE_NAME = { @@ -81,6 +82,11 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] if isinstance(base_model, ModelAdaptersMixin): # HACK to preserve original forward method signature (e.g. for Trainer label names) temp_signature = inspect.signature(model.forward.__func__) + params = list(temp_signature.parameters.values()) + # add forward context args to signature + for param_name in ForwardContext.context_args: + params.append(inspect.Parameter(param_name, inspect.Parameter.KEYWORD_ONLY)) + temp_signature = temp_signature.replace(parameters=params) # Create new wrapper model class model_class_name = model.__class__.__name__ model_class = type(