Skip to content

Commit

Permalink
Add ForwardContext args to wrapped signature
Browse files Browse the repository at this point in the history
  • Loading branch information
calpt committed Feb 21, 2025
1 parent 8f17f5e commit de309e4
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions src/adapters/wrappers/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from ..models import MODEL_MIXIN_MAPPING
from .configuration import init_adapters_config
from ..context import ForwardContext


SPECIAL_MODEL_TYPE_TO_MODULE_NAME = {
Expand Down Expand Up @@ -81,6 +82,11 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig]
if isinstance(base_model, ModelAdaptersMixin):
# HACK to preserve original forward method signature (e.g. for Trainer label names)
temp_signature = inspect.signature(model.forward.__func__)
params = list(temp_signature.parameters.values())
# add forward context args to signature
for param_name in ForwardContext.context_args:
params.append(inspect.Parameter(param_name, inspect.Parameter.KEYWORD_ONLY))
temp_signature = temp_signature.replace(parameters=params)
# Create new wrapper model class
model_class_name = model.__class__.__name__
model_class = type(
Expand Down

0 comments on commit de309e4

Please sign in to comment.