From 096297a6c2074c0bd8ca79a73083472039424373 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 2 Feb 2025 20:26:06 +0100 Subject: [PATCH] fix signature --- src/adapters/model_mixin.py | 2 -- src/adapters/wrappers/model.py | 4 ++++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index c4ec9514d5..8303191a52 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1703,7 +1703,6 @@ def post_embedding_forward(self, module, args, embedding_output): @ForwardContext.wrap_base def forward(self, *args, **kwargs): - print("base context: ", ForwardContext.get_context().__dict__) return super().forward(*args, **kwargs) @@ -2245,5 +2244,4 @@ def freeze_embeddings(self, freeze=True): @ForwardContext.wrap def forward(self, *args, **kwargs): - print("head context: ", ForwardContext.get_context().__dict__) return super().forward(*args, **kwargs) diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py index 12ed79e122..1f54e29ca1 100644 --- a/src/adapters/wrappers/model.py +++ b/src/adapters/wrappers/model.py @@ -1,4 +1,5 @@ import importlib +import inspect import os from typing import Any, Optional, Type, Union @@ -78,6 +79,8 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] if hasattr(model, "base_model_prefix") and hasattr(model, model.base_model_prefix): base_model = getattr(model, model.base_model_prefix) if isinstance(base_model, ModelAdaptersMixin): + # HACK to preserve original forward method signature (e.g. for Trainer label names) + temp_signature = inspect.signature(model.forward.__func__) # Create new wrapper model class model_class_name = model.__class__.__name__ model_class = type( @@ -86,6 +89,7 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] {}, ) model.__class__ = model_class + model.forward.__func__.__signature__ = temp_signature # Finally, initialize adapters model.init_adapters(model.config, adapters_config)