From 6424db85574b855c39846b6615a84761f0e7e557 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 2 Feb 2025 18:47:35 +0100 Subject: [PATCH 1/8] Wrap ForwardContext around full model forward --- src/adapters/context.py | 64 +++++++++++-------- src/adapters/heads/model_mixin.py | 6 +- src/adapters/model_mixin.py | 13 ++-- src/adapters/models/albert/adapter_model.py | 14 +--- src/adapters/models/bart/adapter_model.py | 12 +--- src/adapters/models/beit/adapter_model.py | 14 +--- src/adapters/models/bert/adapter_model.py | 13 +--- .../models/bert_generation/adapter_model.py | 13 +--- src/adapters/models/clip/adapter_model.py | 14 +--- src/adapters/models/deberta/adapter_model.py | 13 +--- .../models/deberta_v2/adapter_model.py | 13 +--- .../models/distilbert/adapter_model.py | 13 +--- src/adapters/models/electra/adapter_model.py | 14 +--- src/adapters/models/gpt2/adapter_model.py | 13 +--- src/adapters/models/gptj/adapter_model.py | 13 +--- src/adapters/models/llama/adapter_model.py | 13 +--- src/adapters/models/mbart/adapter_model.py | 12 +--- src/adapters/models/mistral/adapter_model.py | 13 +--- src/adapters/models/mt5/adapter_model.py | 12 +--- src/adapters/models/plbart/adapter_model.py | 13 +--- src/adapters/models/roberta/adapter_model.py | 13 +--- src/adapters/models/t5/adapter_model.py | 12 +--- src/adapters/models/vit/adapter_model.py | 11 +--- src/adapters/models/whisper/adapter_model.py | 13 +--- .../models/xlm_roberta/adapter_model.py | 13 +--- src/adapters/models/xmod/adapter_model.py | 13 +--- 26 files changed, 121 insertions(+), 259 deletions(-) diff --git a/src/adapters/context.py b/src/adapters/context.py index db09b8918f..df0ae91f45 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -79,16 +79,22 @@ class ForwardContext(ContextManager): # thread-local storage that holds a stack of active contexts storage = threading.local() + context_args = [ + "output_adapter_gating_scores", + "output_adapter_fusion_attentions", + "adapter_input_parallelized", + ] context_attributes = [ "adapter_gating_scores", "adapter_fusion_attentions", - "adapter_input_parallelized", ] # Additional used attributes not exposed to the user # - prompt_tokens_length: length of the prompt tokens def __init__(self, model, *args, **kwargs): # If the model has a method ``forward_context()``, use it to create the context. + for arg_name in self.context_args: + setattr(self, arg_name, kwargs.pop(arg_name, None)) if hasattr(model, "forward_context"): model.forward_context(self, *args, **kwargs) @@ -99,6 +105,36 @@ def __enter__(self): def __exit__(self, type, value, traceback): ForwardContext.get_contexts().pop() + def _call_forward(self, model, f, *args, **kwargs): + kwargs = {k: v for k, v in kwargs.items() if k not in self.context_args} + results = f(model, *args, **kwargs) + + # append output attributes + if isinstance(results, tuple): + for attr in self.context_attributes: + if getattr(self, "output_" + attr, False): + results = results + (dict(getattr(self, attr)),) + else: + for attr in self.context_attributes: + if getattr(self, "output_" + attr, False): + results[attr] = dict(getattr(self, attr)) + + return results + + @classmethod + def wrap_base(cls, f): + + @functools.wraps(f) + def wrapper_func(self, *args, **kwargs): + if self.adapters_config is not None and ForwardContext.get_context() is None: + with cls(self, *args, **kwargs) as ctx: + results = ctx._call_forward(self, f, *args, **kwargs) + return results + else: + return f(self, *args, **kwargs) + + return wrapper_func + @classmethod def wrap(cls, f): """ @@ -109,30 +145,8 @@ def wrap(cls, f): def wrapper_func(self, *args, **kwargs): if self.adapters_config is not None: with cls(self, *args, **kwargs) as ctx: - # whether to output the context attributes - output_context = kwargs.pop("output_context", False) - kwargs = { - k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes - } - results = f(self, *args, **kwargs) - - # append output attributes - if isinstance(results, tuple): - for attr in cls.context_attributes: - if getattr(ctx, "output_" + attr, False): - results = results + (dict(getattr(ctx, attr)),) - else: - for attr in cls.context_attributes: - if getattr(ctx, "output_" + attr, False): - results[attr] = dict(getattr(ctx, attr)) - - if output_context: - context_dict = ctx.__dict__ - - if output_context: - return results, context_dict - else: - return results + results = ctx._call_forward(self, f, *args, **kwargs) + return results else: return f(self, *args, **kwargs) diff --git a/src/adapters/heads/model_mixin.py b/src/adapters/heads/model_mixin.py index ced4ff0753..ccc729f7b2 100644 --- a/src/adapters/heads/model_mixin.py +++ b/src/adapters/heads/model_mixin.py @@ -584,8 +584,12 @@ def _get_head_input(outputs, cls_out, batch): kwargs["invertible_adapter"] = inv_adapter # Set prompt tokens length + context = context or ForwardContext.get_context() if context is not None: - prompt_tokens_length = context.get("prompt_tokens_length", None) + if isinstance(context, ForwardContext): + prompt_tokens_length = getattr(context, "prompt_tokens_length", None) + else: + prompt_tokens_length = context.get("prompt_tokens_length", None) if prompt_tokens_length is not None: kwargs["prompt_tokens_length"] = prompt_tokens_length diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index 1895671f8d..c4ec9514d5 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1138,8 +1138,7 @@ def forward_context(self, context: ForwardContext, *args, **kwargs): context.adapters_parallelized = False # Check if already parallelized in encoder - adapter_input_parallelized = kwargs.pop("adapter_input_parallelized", None) - if adapter_input_parallelized: + if context.adapter_input_parallelized: if active_adapters.parallel_channels > 1: context.adapters_parallelized = True # Add the shared parameters for the active adapters to the context @@ -1165,8 +1164,6 @@ def forward_context(self, context: ForwardContext, *args, **kwargs): context.offsets = attention_mask.argmax(1) # Adapter gating and attention outputs - context.output_adapter_gating_scores = kwargs.get("output_adapter_gating_scores", False) - context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False) context.adapter_gating_scores = defaultdict(dict) context.adapter_fusion_attentions = defaultdict(dict) @@ -1704,8 +1701,9 @@ def post_embedding_forward(self, module, args, embedding_output): return embedding_output - @ForwardContext.wrap + @ForwardContext.wrap_base def forward(self, *args, **kwargs): + print("base context: ", ForwardContext.get_context().__dict__) return super().forward(*args, **kwargs) @@ -2244,3 +2242,8 @@ def freeze_embeddings(self, freeze=True): else: for p in self.get_output_embeddings().parameters(): p.requires_grad = not freeze + + @ForwardContext.wrap + def forward(self, *args, **kwargs): + print("head context: ", ForwardContext.get_context().__dict__) + return super().forward(*args, **kwargs) diff --git a/src/adapters/models/albert/adapter_model.py b/src/adapters/models/albert/adapter_model.py index 73892bb2ff..c59ded93ab 100644 --- a/src/adapters/models/albert/adapter_model.py +++ b/src/adapters/models/albert/adapter_model.py @@ -6,7 +6,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -38,6 +38,7 @@ def __init__(self, config): self.init_weights() @add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -50,8 +51,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -66,7 +65,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.albert( + outputs = self.albert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -76,14 +75,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - # BERT & RoBERTa & ALBERT return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/bart/adapter_model.py b/src/adapters/models/bart/adapter_model.py index 34a5615644..fe5f89d9d8 100644 --- a/src/adapters/models/bart/adapter_model.py +++ b/src/adapters/models/bart/adapter_model.py @@ -11,6 +11,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -50,6 +51,7 @@ def get_decoder(self): return self.model.get_decoder() @add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -68,8 +70,6 @@ def forward( return_dict=None, past_key_values=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): r""" @@ -82,7 +82,7 @@ def forward( if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: use_cache = False - outputs, context = self.model( + outputs = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -98,13 +98,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, past_key_values=past_key_values, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context head_outputs = self.forward_head( outputs, diff --git a/src/adapters/models/beit/adapter_model.py b/src/adapters/models/beit/adapter_model.py index 578142ea11..834d21fdbf 100644 --- a/src/adapters/models/beit/adapter_model.py +++ b/src/adapters/models/beit/adapter_model.py @@ -10,7 +10,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...wrappers import init @@ -51,6 +51,7 @@ def make_inputs_require_grads(module, input, output): self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads) @add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, pixel_values: Optional[torch.Tensor] = None, @@ -60,27 +61,18 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.beit( + outputs = self.beit( pixel_values, bool_masked_pos=bool_masked_pos, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/bert/adapter_model.py b/src/adapters/models/bert/adapter_model.py index 3be78bd5bd..abf7f88f5b 100644 --- a/src/adapters/models/bert/adapter_model.py +++ b/src/adapters/models/bert/adapter_model.py @@ -7,7 +7,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -43,6 +43,7 @@ def __init__(self, config): self.init_weights() @add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -55,8 +56,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -71,7 +70,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.bert( + outputs = self.bert( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -81,13 +80,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/bert_generation/adapter_model.py b/src/adapters/models/bert_generation/adapter_model.py index 0bbe5ad51f..eae5521f81 100644 --- a/src/adapters/models/bert_generation/adapter_model.py +++ b/src/adapters/models/bert_generation/adapter_model.py @@ -7,7 +7,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -38,6 +38,7 @@ def __init__(self, config): self.init_weights() @add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -53,8 +54,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -68,7 +67,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.bert( + outputs = self.bert( input_ids, attention_mask=attention_mask, position_ids=position_ids, @@ -81,13 +80,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/clip/adapter_model.py b/src/adapters/models/clip/adapter_model.py index 7734cd0212..19fbf57fae 100644 --- a/src/adapters/models/clip/adapter_model.py +++ b/src/adapters/models/clip/adapter_model.py @@ -10,7 +10,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -29,6 +29,7 @@ def __init__(self, config): self.post_init() @add_start_docstrings_to_model_forward(CLIP_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, input_ids: Optional[torch.LongTensor] = None, @@ -40,11 +41,9 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): - outputs, context = self.clip( + outputs = self.clip( input_ids=input_ids, pixel_values=pixel_values, attention_mask=attention_mask, @@ -53,14 +52,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - if head or AdapterSetup.get_context_head_setup() or self.active_head: head_outputs = self.forward_head( outputs, diff --git a/src/adapters/models/deberta/adapter_model.py b/src/adapters/models/deberta/adapter_model.py index f5e15e8cb7..82f9fc5a82 100644 --- a/src/adapters/models/deberta/adapter_model.py +++ b/src/adapters/models/deberta/adapter_model.py @@ -1,7 +1,7 @@ from transformers.file_utils import add_start_docstrings from transformers.models.deberta.modeling_deberta import DebertaModel, DebertaPreTrainedModel -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -32,6 +32,7 @@ def __init__(self, config): self.init_weights() + @ForwardContext.wrap def forward( self, input_ids=None, @@ -43,8 +44,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -59,7 +58,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.deberta( + outputs = self.deberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -68,13 +67,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/deberta_v2/adapter_model.py b/src/adapters/models/deberta_v2/adapter_model.py index 07092debdb..dfbfbfbc09 100644 --- a/src/adapters/models/deberta_v2/adapter_model.py +++ b/src/adapters/models/deberta_v2/adapter_model.py @@ -1,7 +1,7 @@ from transformers.file_utils import add_start_docstrings from transformers.models.deberta_v2.modeling_deberta_v2 import DebertaV2Model, DebertaV2PreTrainedModel -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -34,6 +34,7 @@ def __init__(self, config): self.init_weights() + @ForwardContext.wrap def forward( self, input_ids=None, @@ -45,8 +46,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -61,7 +60,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.deberta( + outputs = self.deberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -70,13 +69,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/distilbert/adapter_model.py b/src/adapters/models/distilbert/adapter_model.py index d7b09dfe1e..63f633f730 100644 --- a/src/adapters/models/distilbert/adapter_model.py +++ b/src/adapters/models/distilbert/adapter_model.py @@ -9,6 +9,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -63,6 +64,7 @@ def resize_position_embeddings(self, new_num_position_embeddings: int): self.distilbert.resize_position_embeddings(new_num_position_embeddings) @add_start_docstrings_to_model_forward(DISTILBERT_INPUTS_DOCSTRING.format("batch_size, num_choices")) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -73,8 +75,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -87,7 +87,7 @@ def forward( else None ) - distilbert_output, context = self.distilbert( + distilbert_output = self.distilbert( input_ids=input_ids, attention_mask=attention_mask, head_mask=head_mask, @@ -95,14 +95,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - outputs = self.forward_head( distilbert_output, head_name=head, attention_mask=attention_mask, return_dict=return_dict, **kwargs ) diff --git a/src/adapters/models/electra/adapter_model.py b/src/adapters/models/electra/adapter_model.py index 83bc8f9184..108fb064a9 100644 --- a/src/adapters/models/electra/adapter_model.py +++ b/src/adapters/models/electra/adapter_model.py @@ -7,7 +7,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -43,6 +43,7 @@ def __init__(self, config): self.init_weights() @add_start_docstrings_to_model_forward(ELECTRA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -55,8 +56,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -71,7 +70,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.electra( + outputs = self.electra( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -81,14 +80,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - head_inputs = outputs if head or AdapterSetup.get_context_head_setup() or self.active_head: diff --git a/src/adapters/models/gpt2/adapter_model.py b/src/adapters/models/gpt2/adapter_model.py index c6b96d1204..72386bd7d9 100644 --- a/src/adapters/models/gpt2/adapter_model.py +++ b/src/adapters/models/gpt2/adapter_model.py @@ -7,6 +7,7 @@ from transformers.utils import add_start_docstrings from ...composition import adjust_tensors_for_parallel +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -50,6 +51,7 @@ def __init__(self, config): self.model_parallel = False self.device_map = None + @ForwardContext.wrap def forward( self, input_ids=None, @@ -66,13 +68,11 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.transformer( + outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -86,14 +86,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - batch_size = outputs[0].shape[0] if self.config.pad_token_id is None: diff --git a/src/adapters/models/gptj/adapter_model.py b/src/adapters/models/gptj/adapter_model.py index c075aeac1a..8549b34ace 100644 --- a/src/adapters/models/gptj/adapter_model.py +++ b/src/adapters/models/gptj/adapter_model.py @@ -7,6 +7,7 @@ from transformers.utils import add_start_docstrings from ...composition import adjust_tensors_for_parallel +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -50,6 +51,7 @@ def __init__(self, config): self.model_parallel = False self.device_map = None + @ForwardContext.wrap def forward( self, input_ids=None, @@ -64,13 +66,11 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.transformer( + outputs = self.transformer( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -82,14 +82,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - batch_size = outputs[0].shape[0] if self.config.pad_token_id is None: diff --git a/src/adapters/models/llama/adapter_model.py b/src/adapters/models/llama/adapter_model.py index 39d93ad9b5..aea8e48a3e 100644 --- a/src/adapters/models/llama/adapter_model.py +++ b/src/adapters/models/llama/adapter_model.py @@ -8,6 +8,7 @@ from transformers.utils import add_start_docstrings from ...composition import adjust_tensors_for_parallel +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -52,6 +53,7 @@ def __init__(self, config): self.device_map = None self.post_init() + @ForwardContext.wrap def forward( self, input_ids=None, @@ -65,8 +67,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -75,7 +75,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.model( + outputs = self.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -86,14 +86,7 @@ def forward( output_attentions=output_attentions, return_dict=return_dict, output_hidden_states=output_hidden_states, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - batch_size = outputs[0].shape[0] if self.config.pad_token_id is None: diff --git a/src/adapters/models/mbart/adapter_model.py b/src/adapters/models/mbart/adapter_model.py index 06e31650fa..0e88373c48 100644 --- a/src/adapters/models/mbart/adapter_model.py +++ b/src/adapters/models/mbart/adapter_model.py @@ -12,6 +12,7 @@ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward from ...composition import adjust_tensors_for_parallel +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -51,6 +52,7 @@ def get_decoder(self): return self.model.get_decoder() @add_start_docstrings_to_model_forward(MBART_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -69,8 +71,6 @@ def forward( return_dict=None, past_key_values=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): r""" @@ -83,7 +83,7 @@ def forward( if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: use_cache = False - outputs, context = self.model( + outputs = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -99,13 +99,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, past_key_values=past_key_values, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # sequence classification based on last token in sequence x = outputs[0] # last hidden state if input_ids is not None and x.shape[1] == input_ids.shape[1]: diff --git a/src/adapters/models/mistral/adapter_model.py b/src/adapters/models/mistral/adapter_model.py index 595cace188..984506c127 100644 --- a/src/adapters/models/mistral/adapter_model.py +++ b/src/adapters/models/mistral/adapter_model.py @@ -7,6 +7,7 @@ from transformers.utils import add_start_docstrings from ...composition import adjust_tensors_for_parallel +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -51,6 +52,7 @@ def __init__(self, config): self.device_map = None self.post_init() + @ForwardContext.wrap def forward( self, input_ids=None, @@ -63,8 +65,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions @@ -73,7 +73,7 @@ def forward( ) return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.model( + outputs = self.model( input_ids, past_key_values=past_key_values, attention_mask=attention_mask, @@ -83,14 +83,7 @@ def forward( output_attentions=output_attentions, return_dict=return_dict, output_hidden_states=output_hidden_states, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - batch_size = outputs[0].shape[0] if self.config.pad_token_id is None: diff --git a/src/adapters/models/mt5/adapter_model.py b/src/adapters/models/mt5/adapter_model.py index 705d0852ef..73d1d80611 100644 --- a/src/adapters/models/mt5/adapter_model.py +++ b/src/adapters/models/mt5/adapter_model.py @@ -12,6 +12,7 @@ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward from ...composition import adjust_tensors_for_parallel +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin, Seq2SeqLMHead from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -63,6 +64,7 @@ def get_decoder(self): return self.transformer.decoder @add_start_docstrings_to_model_forward(MT5_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -82,8 +84,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -96,7 +96,7 @@ def forward( # decoder_input_ids from input_ids if no decoder_input_ids are provided decoder_input_ids = self._shift_right(input_ids) - model_output, context = self.transformer( + model_output = self.transformer( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -112,13 +112,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context sequence_output = model_output[0] # ToDo move head to device for parallel forward pass diff --git a/src/adapters/models/plbart/adapter_model.py b/src/adapters/models/plbart/adapter_model.py index 0475fd077d..8199c83280 100644 --- a/src/adapters/models/plbart/adapter_model.py +++ b/src/adapters/models/plbart/adapter_model.py @@ -11,6 +11,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -50,6 +51,7 @@ def get_decoder(self): return self.model.get_decoder() @add_start_docstrings_to_model_forward(PLBART_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -68,8 +70,6 @@ def forward( return_dict=None, past_key_values=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): r""" @@ -82,7 +82,7 @@ def forward( if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs: use_cache = False - outputs, context = self.model( + outputs = self.model( input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -98,14 +98,7 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, past_key_values=past_key_values, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - head_outputs = self.forward_head( outputs, head_name=head, diff --git a/src/adapters/models/roberta/adapter_model.py b/src/adapters/models/roberta/adapter_model.py index 5a9af959d8..e3eaf1baa7 100644 --- a/src/adapters/models/roberta/adapter_model.py +++ b/src/adapters/models/roberta/adapter_model.py @@ -7,7 +7,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -42,6 +42,7 @@ def __init__(self, config): self.init_weights() @add_start_docstrings_to_model_forward(ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -54,8 +55,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -70,7 +69,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.roberta( + outputs = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -80,13 +79,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/t5/adapter_model.py b/src/adapters/models/t5/adapter_model.py index 5f2b324380..607d9c9ca8 100644 --- a/src/adapters/models/t5/adapter_model.py +++ b/src/adapters/models/t5/adapter_model.py @@ -7,6 +7,7 @@ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward from ...composition import adjust_tensors_for_parallel +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin, Seq2SeqLMHead from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -56,6 +57,7 @@ def get_decoder(self): return self.transformer.decoder @add_start_docstrings_to_model_forward(T5_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -75,8 +77,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict @@ -89,7 +89,7 @@ def forward( # decoder_input_ids from input_ids if no decoder_input_ids are provided decoder_input_ids = self._shift_right(input_ids) - model_output, context = self.transformer( + model_output = self.transformer( input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -105,13 +105,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context sequence_output = model_output[0] # ToDo move head to device for parallel forward pass diff --git a/src/adapters/models/vit/adapter_model.py b/src/adapters/models/vit/adapter_model.py index ece9ec5214..c08cef173a 100644 --- a/src/adapters/models/vit/adapter_model.py +++ b/src/adapters/models/vit/adapter_model.py @@ -10,7 +10,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...wrappers import init @@ -37,6 +37,7 @@ def __init__(self, config): self.post_init() @add_start_docstrings_to_model_forward(VIT_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, pixel_values: Optional[torch.Tensor] = None, @@ -52,20 +53,14 @@ def forward( ): return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.vit( + outputs = self.vit( pixel_values, head_mask=head_mask, output_attentions=output_attentions, output_hidden_states=output_hidden_states, interpolate_pos_encoding=interpolate_pos_encoding, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: diff --git a/src/adapters/models/whisper/adapter_model.py b/src/adapters/models/whisper/adapter_model.py index 4bcc026927..6a8adbe98e 100644 --- a/src/adapters/models/whisper/adapter_model.py +++ b/src/adapters/models/whisper/adapter_model.py @@ -12,6 +12,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward +from ...context import ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -52,6 +53,7 @@ def freeze_encoder(self): self.model.encoder._freeze_parameters() @add_start_docstrings_to_model_forward(WHISPER_INPUTS_DOCSTRING) + @ForwardContext.wrap def forward( self, input_features=None, @@ -69,8 +71,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): r""" @@ -94,7 +94,7 @@ def forward( # raise ValueError(ValueError: The following model_kwargs are not used by the model: ['labels'] # This is because we do not specify labels as parameter in the forward method - outputs, context = self.model( + outputs = self.model( input_features=input_features, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids, @@ -109,15 +109,8 @@ def forward( output_hidden_states=output_hidden_states, return_dict=return_dict, past_key_values=past_key_values, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context - head_outputs = self.forward_head( outputs, head_name=head, diff --git a/src/adapters/models/xlm_roberta/adapter_model.py b/src/adapters/models/xlm_roberta/adapter_model.py index 559202d52d..b4a592f16c 100644 --- a/src/adapters/models/xlm_roberta/adapter_model.py +++ b/src/adapters/models/xlm_roberta/adapter_model.py @@ -7,7 +7,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -43,6 +43,7 @@ def __init__(self, config): self.init_weights() @add_start_docstrings_to_model_forward(XLM_ROBERTA_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @ForwardContext.wrap def forward( self, input_ids=None, @@ -55,8 +56,6 @@ def forward( output_hidden_states=None, return_dict=None, head=None, - output_adapter_gating_scores=False, - output_adapter_fusion_attentions=False, **kwargs, ): input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None @@ -71,7 +70,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.roberta( + outputs = self.roberta( input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids, @@ -81,13 +80,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] diff --git a/src/adapters/models/xmod/adapter_model.py b/src/adapters/models/xmod/adapter_model.py index e81f49dee0..e40ff16746 100644 --- a/src/adapters/models/xmod/adapter_model.py +++ b/src/adapters/models/xmod/adapter_model.py @@ -11,7 +11,7 @@ ) from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward -from ...context import AdapterSetup +from ...context import AdapterSetup, ForwardContext from ...heads import ModelWithFlexibleHeadsAdaptersMixin from ...model_mixin import EmbeddingAdaptersWrapperMixin from ...wrappers import init @@ -47,6 +47,7 @@ def __init__(self, config): self.init_weights() @add_start_docstrings_to_model_forward(XMOD_INPUTS_DOCSTRING.format("batch_size, sequence_length")) + @ForwardContext.wrap def forward( self, input_ids: Optional[torch.Tensor] = None, @@ -60,8 +61,6 @@ def forward( output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, head: Optional[str] = None, - output_adapter_gating_scores: Optional[bool] = False, - output_adapter_fusion_attentions: Optional[bool] = False, **kwargs, ): # Flatten for multiple choice tasks @@ -78,7 +77,7 @@ def forward( return_dict = return_dict if return_dict is not None else self.config.use_return_dict - outputs, context = self.roberta( + outputs = self.roberta( input_ids, lang_ids=lang_ids, attention_mask=attention_mask, @@ -89,13 +88,7 @@ def forward( output_attentions=output_attentions, output_hidden_states=output_hidden_states, return_dict=return_dict, - output_adapter_gating_scores=output_adapter_gating_scores, - output_adapter_fusion_attentions=output_adapter_fusion_attentions, - adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False), - output_context=True, ) - # required e.g. for prompt tuning in all models - kwargs["context"] = context # BERT & RoBERTa return the pooled output as second item, we don't need that in these heads if not return_dict: head_inputs = (outputs[0],) + outputs[2:] From 096297a6c2074c0bd8ca79a73083472039424373 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 2 Feb 2025 20:26:06 +0100 Subject: [PATCH 2/8] fix signature --- src/adapters/model_mixin.py | 2 -- src/adapters/wrappers/model.py | 4 ++++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/adapters/model_mixin.py b/src/adapters/model_mixin.py index c4ec9514d5..8303191a52 100644 --- a/src/adapters/model_mixin.py +++ b/src/adapters/model_mixin.py @@ -1703,7 +1703,6 @@ def post_embedding_forward(self, module, args, embedding_output): @ForwardContext.wrap_base def forward(self, *args, **kwargs): - print("base context: ", ForwardContext.get_context().__dict__) return super().forward(*args, **kwargs) @@ -2245,5 +2244,4 @@ def freeze_embeddings(self, freeze=True): @ForwardContext.wrap def forward(self, *args, **kwargs): - print("head context: ", ForwardContext.get_context().__dict__) return super().forward(*args, **kwargs) diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py index 12ed79e122..1f54e29ca1 100644 --- a/src/adapters/wrappers/model.py +++ b/src/adapters/wrappers/model.py @@ -1,4 +1,5 @@ import importlib +import inspect import os from typing import Any, Optional, Type, Union @@ -78,6 +79,8 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] if hasattr(model, "base_model_prefix") and hasattr(model, model.base_model_prefix): base_model = getattr(model, model.base_model_prefix) if isinstance(base_model, ModelAdaptersMixin): + # HACK to preserve original forward method signature (e.g. for Trainer label names) + temp_signature = inspect.signature(model.forward.__func__) # Create new wrapper model class model_class_name = model.__class__.__name__ model_class = type( @@ -86,6 +89,7 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] {}, ) model.__class__ = model_class + model.forward.__func__.__signature__ = temp_signature # Finally, initialize adapters model.init_adapters(model.config, adapters_config) From 8f17f5eb5409a90a15bb229f01a1a19725a95311 Mon Sep 17 00:00:00 2001 From: calpt Date: Mon, 10 Feb 2025 22:51:25 +0100 Subject: [PATCH 3/8] list -> set --- src/adapters/context.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/adapters/context.py b/src/adapters/context.py index df0ae91f45..4dc0f3d1e5 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -79,15 +79,15 @@ class ForwardContext(ContextManager): # thread-local storage that holds a stack of active contexts storage = threading.local() - context_args = [ + context_args = { "output_adapter_gating_scores", "output_adapter_fusion_attentions", "adapter_input_parallelized", - ] - context_attributes = [ + } + context_attributes = { "adapter_gating_scores", "adapter_fusion_attentions", - ] + } # Additional used attributes not exposed to the user # - prompt_tokens_length: length of the prompt tokens From b38c4a92b679c5f7516b63ea0b4bd3c1e766dfcf Mon Sep 17 00:00:00 2001 From: calpt Date: Fri, 21 Feb 2025 16:07:14 +0100 Subject: [PATCH 4/8] Add ForwardContext args to wrapped signature --- src/adapters/wrappers/model.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py index 1f54e29ca1..7a79631670 100644 --- a/src/adapters/wrappers/model.py +++ b/src/adapters/wrappers/model.py @@ -10,6 +10,7 @@ from transformers.models.auto.configuration_auto import model_type_to_module_name from ..configuration import ModelAdaptersConfig +from ..context import ForwardContext from ..model_mixin import ( EmbeddingAdaptersWrapperMixin, ModelAdaptersMixin, @@ -81,6 +82,11 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] if isinstance(base_model, ModelAdaptersMixin): # HACK to preserve original forward method signature (e.g. for Trainer label names) temp_signature = inspect.signature(model.forward.__func__) + params = list(temp_signature.parameters.values()) + # add forward context args to signature + for param_name in ForwardContext.context_args: + params.append(inspect.Parameter(param_name, inspect.Parameter.KEYWORD_ONLY)) + temp_signature = temp_signature.replace(parameters=params) # Create new wrapper model class model_class_name = model.__class__.__name__ model_class = type( From f03d9b74cde929802bbe8c6a7cf45f9ec94758e5 Mon Sep 17 00:00:00 2001 From: calpt Date: Fri, 21 Feb 2025 16:24:40 +0100 Subject: [PATCH 5/8] docs --- src/adapters/context.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/adapters/context.py b/src/adapters/context.py index 4dc0f3d1e5..eeadd99cd8 100644 --- a/src/adapters/context.py +++ b/src/adapters/context.py @@ -106,6 +106,9 @@ def __exit__(self, type, value, traceback): ForwardContext.get_contexts().pop() def _call_forward(self, model, f, *args, **kwargs): + """ + Calls the forward function of the model with the given arguments and keyword arguments. + """ kwargs = {k: v for k, v in kwargs.items() if k not in self.context_args} results = f(model, *args, **kwargs) @@ -123,6 +126,10 @@ def _call_forward(self, model, f, *args, **kwargs): @classmethod def wrap_base(cls, f): + """ + Decorator method that wraps a ``forward()`` function of a base model class. + Unlike ``wrap()``, this method does not create a new context if the is an existing one. + """ @functools.wraps(f) def wrapper_func(self, *args, **kwargs): From c8996fc7133281085e77d81540bc03d733462903 Mon Sep 17 00:00:00 2001 From: calpt Date: Fri, 21 Feb 2025 16:36:22 +0100 Subject: [PATCH 6/8] fix --- src/adapters/wrappers/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py index 7a79631670..35a7ae03fe 100644 --- a/src/adapters/wrappers/model.py +++ b/src/adapters/wrappers/model.py @@ -85,7 +85,7 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] params = list(temp_signature.parameters.values()) # add forward context args to signature for param_name in ForwardContext.context_args: - params.append(inspect.Parameter(param_name, inspect.Parameter.KEYWORD_ONLY)) + params.append(inspect.Parameter(param_name, inspect.Parameter.VAR_KEYWORD)) temp_signature = temp_signature.replace(parameters=params) # Create new wrapper model class model_class_name = model.__class__.__name__ From 1667e73eef395666378d013bd990501776e69466 Mon Sep 17 00:00:00 2001 From: calpt Date: Sun, 23 Feb 2025 16:28:10 +0000 Subject: [PATCH 7/8] change signature param type --- src/adapters/wrappers/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py index 35a7ae03fe..14696eda4b 100644 --- a/src/adapters/wrappers/model.py +++ b/src/adapters/wrappers/model.py @@ -85,7 +85,7 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] params = list(temp_signature.parameters.values()) # add forward context args to signature for param_name in ForwardContext.context_args: - params.append(inspect.Parameter(param_name, inspect.Parameter.VAR_KEYWORD)) + params.append(inspect.Parameter(param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None)) temp_signature = temp_signature.replace(parameters=params) # Create new wrapper model class model_class_name = model.__class__.__name__ From d4756a4687847112a3cc207c47719b50f490b5f6 Mon Sep 17 00:00:00 2001 From: calpt Date: Wed, 26 Feb 2025 20:49:44 +0100 Subject: [PATCH 8/8] revert signature editing --- src/adapters/wrappers/model.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/adapters/wrappers/model.py b/src/adapters/wrappers/model.py index 14696eda4b..1f54e29ca1 100644 --- a/src/adapters/wrappers/model.py +++ b/src/adapters/wrappers/model.py @@ -10,7 +10,6 @@ from transformers.models.auto.configuration_auto import model_type_to_module_name from ..configuration import ModelAdaptersConfig -from ..context import ForwardContext from ..model_mixin import ( EmbeddingAdaptersWrapperMixin, ModelAdaptersMixin, @@ -82,11 +81,6 @@ def init(model: PreTrainedModel, adapters_config: Optional[ModelAdaptersConfig] if isinstance(base_model, ModelAdaptersMixin): # HACK to preserve original forward method signature (e.g. for Trainer label names) temp_signature = inspect.signature(model.forward.__func__) - params = list(temp_signature.parameters.values()) - # add forward context args to signature - for param_name in ForwardContext.context_args: - params.append(inspect.Parameter(param_name, inspect.Parameter.POSITIONAL_OR_KEYWORD, default=None)) - temp_signature = temp_signature.replace(parameters=params) # Create new wrapper model class model_class_name = model.__class__.__name__ model_class = type(