Wrap ForwardContext around full model forward (#789)
This PR adapts the ForwardContext so that it wraps the full model forward pass
(including the prediction head). The original base model forward wrapper has been
moved to `wrap_base` to ensure that no second ForwardContext is created for a
single forward pass.
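
For illustration, here is a standalone toy sketch of that split (the classes below are made up for this example and are not the library's actual code): `wrap` always opens a context around the full forward pass, while `wrap_base` only opens one when called with no context already active:
```python
import functools


class ToyContext:
    """Toy stand-in for ForwardContext, only to illustrate wrap vs. wrap_base."""

    _stack = []  # the real ForwardContext uses thread-local storage

    def __enter__(self):
        ToyContext._stack.append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        ToyContext._stack.pop()

    @classmethod
    def wrap(cls, f):
        # Full (head) model forward: always opens a context.
        @functools.wraps(f)
        def wrapper(self, *args, **kwargs):
            with cls():
                return f(self, *args, **kwargs)

        return wrapper

    @classmethod
    def wrap_base(cls, f):
        # Base model forward: reuses an already open context instead of nesting a second one.
        @functools.wraps(f)
        def wrapper(self, *args, **kwargs):
            if not cls._stack:
                with cls():
                    return f(self, *args, **kwargs)
            return f(self, *args, **kwargs)

        return wrapper


class ToyBaseModel:
    @ToyContext.wrap_base
    def forward(self):
        return len(ToyContext._stack)


class ToyModelWithHead(ToyBaseModel):
    @ToyContext.wrap
    def forward(self):
        return super().forward()


print(ToyModelWithHead().forward())  # 1 -- a single context wraps the full forward pass
print(ToyBaseModel().forward())      # 1 -- the base model alone still opens its own context
```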

This enables passing custom args that are registered on the ForwardContext to the
top-level model call, as discussed in #783, e.g.:
```python
model = AutoModelForCausalLM.from_pretrained(model_name)
adapters.init(model)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(["This is a test text"], return_tensors="pt")

# Registers new forward args globally
ForwardContext.context_args.add("task_ids")

# Now the new arg name can be used without modifying the model's forward method
output = model(**inputs, task_ids=["id_0", "id_1"])
```
In the example above, the forward context automatically adds the passed context
args as attributes, i.e. they can be accessed within the forward pass like this:
```python
task_ids = ForwardContext.get_context().task_ids
```
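
One possible way to consume such a value during the forward pass without touching any `forward()` signature is a plain PyTorch module hook that reads it from the active context. This sketch reuses `model` and `inputs` from the example above; the pre-hook is only for illustration and not part of the adapters API:
```python
from adapters.context import ForwardContext  # defined in src/adapters/context.py


def read_task_ids(module, args):
    ctx = ForwardContext.get_context()
    if ctx is not None:
        print("task_ids seen inside the forward pass:", ctx.task_ids)


# Hook the input embeddings so the print fires while the ForwardContext is active.
handle = model.get_input_embeddings().register_forward_pre_hook(read_task_ids)
output = model(**inputs, task_ids=["id_0", "id_1"])  # hook prints ['id_0', 'id_1']
handle.remove()
```
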
calpt authored Feb 26, 2025
1 parent e5a8689 commit acca075
Showing 27 changed files with 132 additions and 261 deletions.
75 changes: 48 additions & 27 deletions src/adapters/context.py
@@ -79,16 +79,22 @@ class ForwardContext(ContextManager):
# thread-local storage that holds a stack of active contexts
storage = threading.local()

context_attributes = [
context_args = {
"output_adapter_gating_scores",
"output_adapter_fusion_attentions",
"adapter_input_parallelized",
}
context_attributes = {
"adapter_gating_scores",
"adapter_fusion_attentions",
"adapter_input_parallelized",
]
}
# Additional internally used attributes that are not exposed to the user
# - prompt_tokens_length: length of the prompt tokens

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
for arg_name in self.context_args:
setattr(self, arg_name, kwargs.pop(arg_name, None))
if hasattr(model, "forward_context"):
model.forward_context(self, *args, **kwargs)

@@ -99,6 +105,43 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
ForwardContext.get_contexts().pop()

def _call_forward(self, model, f, *args, **kwargs):
"""
Calls the forward function of the model with the given arguments and keyword arguments.
"""
kwargs = {k: v for k, v in kwargs.items() if k not in self.context_args}
results = f(model, *args, **kwargs)

# append output attributes
if isinstance(results, tuple):
for attr in self.context_attributes:
if getattr(self, "output_" + attr, False):
results = results + (dict(getattr(self, attr)),)
else:
for attr in self.context_attributes:
if getattr(self, "output_" + attr, False):
results[attr] = dict(getattr(self, attr))

return results

@classmethod
def wrap_base(cls, f):
"""
Decorator method that wraps a ``forward()`` function of a base model class.
Unlike ``wrap()``, this method does not create a new context if there already is an existing one.
"""

@functools.wraps(f)
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None and ForwardContext.get_context() is None:
with cls(self, *args, **kwargs) as ctx:
results = ctx._call_forward(self, f, *args, **kwargs)
return results
else:
return f(self, *args, **kwargs)

return wrapper_func

@classmethod
def wrap(cls, f):
"""
@@ -109,30 +152,8 @@ def wrap(cls, f):
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None:
with cls(self, *args, **kwargs) as ctx:
# whether to output the context attributes
output_context = kwargs.pop("output_context", False)
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
results = f(self, *args, **kwargs)

# append output attributes
if isinstance(results, tuple):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results = results + (dict(getattr(ctx, attr)),)
else:
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))

if output_context:
context_dict = ctx.__dict__

if output_context:
return results, context_dict
else:
return results
results = ctx._call_forward(self, f, *args, **kwargs)
return results
else:
return f(self, *args, **kwargs)

6 changes: 5 additions & 1 deletion src/adapters/heads/model_mixin.py
@@ -584,8 +584,12 @@ def _get_head_input(outputs, cls_out, batch):
kwargs["invertible_adapter"] = inv_adapter

# Set prompt tokens length
context = context or ForwardContext.get_context()
if context is not None:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if isinstance(context, ForwardContext):
prompt_tokens_length = getattr(context, "prompt_tokens_length", None)
else:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if prompt_tokens_length is not None:
kwargs["prompt_tokens_length"] = prompt_tokens_length

11 changes: 6 additions & 5 deletions src/adapters/model_mixin.py
@@ -1138,8 +1138,7 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):

context.adapters_parallelized = False
# Check if already parallelized in encoder
adapter_input_parallelized = kwargs.pop("adapter_input_parallelized", None)
if adapter_input_parallelized:
if context.adapter_input_parallelized:
if active_adapters.parallel_channels > 1:
context.adapters_parallelized = True
# Add the shared parameters for the active adapters to the context
@@ -1165,8 +1164,6 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
context.offsets = attention_mask.argmax(1)

# Adapter gating and attention outputs
context.output_adapter_gating_scores = kwargs.get("output_adapter_gating_scores", False)
context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False)
context.adapter_gating_scores = defaultdict(dict)
context.adapter_fusion_attentions = defaultdict(dict)

@@ -1704,7 +1701,7 @@ def post_embedding_forward(self, module, args, embedding_output):

return embedding_output

@ForwardContext.wrap
@ForwardContext.wrap_base
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)

@@ -2244,3 +2241,7 @@ def freeze_embeddings(self, freeze=True):
else:
for p in self.get_output_embeddings().parameters():
p.requires_grad = not freeze

@ForwardContext.wrap
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
14 changes: 3 additions & 11 deletions src/adapters/models/albert/adapter_model.py
@@ -6,7 +6,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -38,6 +38,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -50,8 +51,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -66,7 +65,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.albert(
outputs = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -76,14 +75,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

# BERT & RoBERTa & ALBERT return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
12 changes: 3 additions & 9 deletions src/adapters/models/bart/adapter_model.py
@@ -11,6 +11,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -50,6 +51,7 @@ def get_decoder(self):
return self.model.get_decoder()

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -68,8 +70,6 @@ def forward(
return_dict=None,
past_key_values=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
r"""
@@ -82,7 +82,7 @@ def forward(
if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs:
use_cache = False

outputs, context = self.model(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
@@ -98,13 +98,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

head_outputs = self.forward_head(
outputs,
14 changes: 3 additions & 11 deletions src/adapters/models/beit/adapter_model.py
@@ -10,7 +10,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...wrappers import init

@@ -51,6 +51,7 @@ def make_inputs_require_grads(module, input, output):
self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)

@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@ForwardContext.wrap
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
@@ -60,27 +61,18 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.beit(
outputs = self.beit(
pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
13 changes: 3 additions & 10 deletions src/adapters/models/bert/adapter_model.py
@@ -7,7 +7,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -43,6 +43,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -55,8 +56,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -71,7 +70,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.bert(
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -81,13 +80,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
13 changes: 3 additions & 10 deletions src/adapters/models/bert_generation/adapter_model.py
@@ -7,7 +7,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -38,6 +38,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -53,8 +54,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -68,7 +67,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.bert(
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -81,13 +80,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]