Wrap ForwardContext around full model forward #789

Merged · 8 commits · Feb 26, 2025
75 changes: 48 additions & 27 deletions src/adapters/context.py
@@ -79,16 +79,22 @@ class ForwardContext(ContextManager):
# thread-local storage that holds a stack of active contexts
storage = threading.local()

context_attributes = [
context_args = {
"output_adapter_gating_scores",
"output_adapter_fusion_attentions",
"adapter_input_parallelized",
}
context_attributes = {
"adapter_gating_scores",
"adapter_fusion_attentions",
"adapter_input_parallelized",
]
}
# Additional attributes used internally, not exposed to the user
# - prompt_tokens_length: length of the prompt tokens

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
for arg_name in self.context_args:
setattr(self, arg_name, kwargs.pop(arg_name, None))
if hasattr(model, "forward_context"):
model.forward_context(self, *args, **kwargs)

@@ -99,6 +105,43 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
ForwardContext.get_contexts().pop()

def _call_forward(self, model, f, *args, **kwargs):
"""
Calls the forward function of the model with the given arguments and keyword arguments.
"""
kwargs = {k: v for k, v in kwargs.items() if k not in self.context_args}
results = f(model, *args, **kwargs)

# append output attributes
if isinstance(results, tuple):
for attr in self.context_attributes:
if getattr(self, "output_" + attr, False):
results = results + (dict(getattr(self, attr)),)
else:
for attr in self.context_attributes:
if getattr(self, "output_" + attr, False):
results[attr] = dict(getattr(self, attr))

return results

@classmethod
def wrap_base(cls, f):
"""
Decorator method that wraps a ``forward()`` function of a base model class.
Unlike ``wrap()``, this method does not create a new context if there is already an active one.
"""

@functools.wraps(f)
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None and ForwardContext.get_context() is None:
with cls(self, *args, **kwargs) as ctx:
results = ctx._call_forward(self, f, *args, **kwargs)
return results
else:
return f(self, *args, **kwargs)

return wrapper_func

@classmethod
def wrap(cls, f):
"""
@@ -109,30 +152,8 @@ def wrap(cls, f):
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None:
with cls(self, *args, **kwargs) as ctx:
# whether to output the context attributes
output_context = kwargs.pop("output_context", False)
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
results = f(self, *args, **kwargs)

# append output attributes
if isinstance(results, tuple):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results = results + (dict(getattr(ctx, attr)),)
else:
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))

if output_context:
context_dict = ctx.__dict__

if output_context:
return results, context_dict
else:
return results
results = ctx._call_forward(self, f, *args, **kwargs)
return results
else:
return f(self, *args, **kwargs)

6 changes: 5 additions & 1 deletion src/adapters/heads/model_mixin.py
@@ -584,8 +584,12 @@ def _get_head_input(outputs, cls_out, batch):
kwargs["invertible_adapter"] = inv_adapter

# Set prompt tokens length
context = context or ForwardContext.get_context()
if context is not None:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if isinstance(context, ForwardContext):
prompt_tokens_length = getattr(context, "prompt_tokens_length", None)
else:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if prompt_tokens_length is not None:
kwargs["prompt_tokens_length"] = prompt_tokens_length

11 changes: 6 additions & 5 deletions src/adapters/model_mixin.py
@@ -1138,8 +1138,7 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):

context.adapters_parallelized = False
# Check if already parallelized in encoder
adapter_input_parallelized = kwargs.pop("adapter_input_parallelized", None)
if adapter_input_parallelized:
if context.adapter_input_parallelized:
if active_adapters.parallel_channels > 1:
context.adapters_parallelized = True
# Add the shared parameters for the active adapters to the context
@@ -1165,8 +1164,6 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
context.offsets = attention_mask.argmax(1)

# Adapter gating and attention outputs
context.output_adapter_gating_scores = kwargs.get("output_adapter_gating_scores", False)
context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False)
context.adapter_gating_scores = defaultdict(dict)
context.adapter_fusion_attentions = defaultdict(dict)

@@ -1704,7 +1701,7 @@ def post_embedding_forward(self, module, args, embedding_output):

return embedding_output

@ForwardContext.wrap
@ForwardContext.wrap_base
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)

@@ -2244,3 +2241,7 @@ def freeze_embeddings(self, freeze=True):
else:
for p in self.get_output_embeddings().parameters():
p.requires_grad = not freeze

@ForwardContext.wrap
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
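
With the model ``forward()`` methods wrapped via these decorators, the per-model ``adapter_model.py`` files below no longer need to hand-forward ``output_adapter_gating_scores`` / ``output_adapter_fusion_attentions`` / ``adapter_input_parallelized`` or the ``output_context`` plumbing; the decorator pops these kwargs into the context and attaches the collected outputs to the result. A hedged usage sketch (model and adapter names are placeholders; the returned scores dict is only populated by setups that actually compute gating scores, e.g. UniPELT-style configurations, and may otherwise be empty):

import torch
from adapters import AutoAdapterModel

model = AutoAdapterModel.from_pretrained("bert-base-uncased")
model.add_adapter("demo")
model.add_classification_head("demo", num_labels=2)
model.set_active_adapters("demo")

input_ids = torch.tensor([[101, 2023, 2003, 1037, 3231, 102]])
# The flag is consumed by ForwardContext.wrap, stripped from the kwargs passed to the
# underlying forward(), and the collected scores are attached to the returned output.
outputs = model(input_ids, output_adapter_gating_scores=True)
print(outputs.get("adapter_gating_scores", {}))
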
14 changes: 3 additions & 11 deletions src/adapters/models/albert/adapter_model.py
@@ -6,7 +6,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -38,6 +38,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -50,8 +51,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -66,7 +65,7 @@

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.albert(
outputs = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -76,14 +75,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

# BERT & RoBERTa & ALBERT return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
12 changes: 3 additions & 9 deletions src/adapters/models/bart/adapter_model.py
@@ -11,6 +11,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -50,6 +51,7 @@ def get_decoder(self):
return self.model.get_decoder()

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -68,8 +70,6 @@ def forward(
return_dict=None,
past_key_values=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
r"""
@@ -82,7 +82,7 @@
if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs:
use_cache = False

outputs, context = self.model(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
@@ -98,13 +98,7 @@
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

head_outputs = self.forward_head(
outputs,
14 changes: 3 additions & 11 deletions src/adapters/models/beit/adapter_model.py
@@ -10,7 +10,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...wrappers import init

@@ -51,6 +51,7 @@ def make_inputs_require_grads(module, input, output):
self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)

@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@ForwardContext.wrap
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
@@ -60,27 +61,18 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.beit(
outputs = self.beit(
pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
13 changes: 3 additions & 10 deletions src/adapters/models/bert/adapter_model.py
@@ -7,7 +7,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -43,6 +43,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -55,8 +56,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -71,7 +70,7 @@

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.bert(
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -81,13 +80,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
13 changes: 3 additions & 10 deletions src/adapters/models/bert_generation/adapter_model.py
@@ -7,7 +7,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -38,6 +38,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -53,8 +54,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -68,7 +67,7 @@

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.bert(
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -81,13 +80,7 @@
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]