Merge branch 'adapter-hub:main' into dev/seed-for-weight-init-653
TimoImhof authored Feb 26, 2025
2 parents 98e5e01 + acca075 commit c5c2981
Showing 27 changed files with 132 additions and 261 deletions.
75 changes: 48 additions & 27 deletions src/adapters/context.py
@@ -79,16 +79,22 @@ class ForwardContext(ContextManager):
# thread-local storage that holds a stack of active contexts
storage = threading.local()

context_attributes = [
context_args = {
"output_adapter_gating_scores",
"output_adapter_fusion_attentions",
"adapter_input_parallelized",
}
context_attributes = {
"adapter_gating_scores",
"adapter_fusion_attentions",
"adapter_input_parallelized",
]
}
# Additional attributes used internally and not exposed to the user
# - prompt_tokens_length: length of the prompt tokens

def __init__(self, model, *args, **kwargs):
# If the model has a method ``forward_context()``, use it to create the context.
for arg_name in self.context_args:
setattr(self, arg_name, kwargs.pop(arg_name, None))
if hasattr(model, "forward_context"):
model.forward_context(self, *args, **kwargs)

@@ -99,6 +105,43 @@ def __enter__(self):
def __exit__(self, type, value, traceback):
ForwardContext.get_contexts().pop()

def _call_forward(self, model, f, *args, **kwargs):
"""
Calls the forward function of the model with the given arguments and keyword arguments.
"""
kwargs = {k: v for k, v in kwargs.items() if k not in self.context_args}
results = f(model, *args, **kwargs)

# append output attributes
if isinstance(results, tuple):
for attr in self.context_attributes:
if getattr(self, "output_" + attr, False):
results = results + (dict(getattr(self, attr)),)
else:
for attr in self.context_attributes:
if getattr(self, "output_" + attr, False):
results[attr] = dict(getattr(self, attr))

return results
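
As a reading aid: the two branches above mean that callers using return_dict=False get each requested dictionary appended to the output tuple, while dict-like ModelOutput results gain an extra entry under the plain attribute name. A tiny standalone illustration with made-up names and values (plain Python objects, not the library's types):

from collections import defaultdict

# A context attribute as it would be collected during the forward pass.
gating = defaultdict(dict)
gating["layer.0"]["my_adapter"] = 0.71

# Branch 1: tuple outputs (return_dict=False) -> requested dicts are appended.
tuple_results = ("last_hidden_state", "pooler_output")
tuple_results = tuple_results + (dict(gating),)
print(tuple_results[-1])                      # {'layer.0': {'my_adapter': 0.71}}

# Branch 2: dict-like outputs -> stored under the attribute name itself.
dict_results = {"last_hidden_state": "..."}
dict_results["adapter_gating_scores"] = dict(gating)
print(dict_results["adapter_gating_scores"])  # {'layer.0': {'my_adapter': 0.71}}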

@classmethod
def wrap_base(cls, f):
"""
Decorator method that wraps a ``forward()`` function of a base model class.
Unlike ``wrap()``, this method does not create a new context if there is already an existing one.
"""

@functools.wraps(f)
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None and ForwardContext.get_context() is None:
with cls(self, *args, **kwargs) as ctx:
results = ctx._call_forward(self, f, *args, **kwargs)
return results
else:
return f(self, *args, **kwargs)

return wrapper_func
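
In practice the two decorators differ only in who owns the context: wrap() below always opens a fresh ForwardContext, while wrap_base() reuses an active one, so a flex-head forward wrapped with wrap() and the base-model forward it calls (wrapped with wrap_base()) share a single context per pass. A stripped-down sketch of that control flow, with simplified names rather than the library implementation:

import threading

_storage = threading.local()

def _contexts():
    if not hasattr(_storage, "contexts"):
        _storage.contexts = []
    return _storage.contexts

class Ctx:
    def __enter__(self):
        _contexts().append(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        _contexts().pop()

def wrap(f):
    # Head-level forward: always owns a context.
    def wrapper(*args, **kwargs):
        with Ctx():
            return f(*args, **kwargs)
    return wrapper

def wrap_base(f):
    # Base-model forward: opens a context only if none is active yet.
    def wrapper(*args, **kwargs):
        if not _contexts():
            with Ctx():
                return f(*args, **kwargs)
        return f(*args, **kwargs)
    return wrapper

@wrap
def head_forward(x):
    return base_forward(x)

@wrap_base
def base_forward(x):
    assert len(_contexts()) == 1  # reuses the context opened by head_forward
    return x

head_forward(3)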

@classmethod
def wrap(cls, f):
"""
@@ -109,30 +152,8 @@ def wrap(cls, f):
def wrapper_func(self, *args, **kwargs):
if self.adapters_config is not None:
with cls(self, *args, **kwargs) as ctx:
# whether to output the context attributes
output_context = kwargs.pop("output_context", False)
kwargs = {
k: v for k, v in kwargs.items() if k.replace("output_", "") not in cls.context_attributes
}
results = f(self, *args, **kwargs)

# append output attributes
if isinstance(results, tuple):
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results = results + (dict(getattr(ctx, attr)),)
else:
for attr in cls.context_attributes:
if getattr(ctx, "output_" + attr, False):
results[attr] = dict(getattr(ctx, attr))

if output_context:
context_dict = ctx.__dict__

if output_context:
return results, context_dict
else:
return results
results = ctx._call_forward(self, f, *args, **kwargs)
return results
else:
return f(self, *args, **kwargs)
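
The net effect for callers: the context kwargs listed in context_args are consumed by ForwardContext.__init__, and the decorated forward() appends or attaches the requested outputs itself, so they no longer need to be threaded through manually or paired with output_context=True. A hedged usage sketch — the checkpoint, adapter name, and the "unipelt" config key are illustrative assumptions, and it presupposes an adapters build that contains this change:

import torch
import adapters
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("bert-base-uncased")
adapters.init(model)                         # adds adapter support, incl. the wrapped forward()
model.add_adapter("demo", config="unipelt")  # a gated method, so gating scores exist
model.set_active_adapters("demo")

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
batch = tokenizer("forward context refactor", return_tensors="pt")

with torch.no_grad():
    outputs = model(**batch, output_adapter_gating_scores=True)

# With dict-like outputs, _call_forward stores the collected scores under the
# plain attribute name from context_attributes.
print(outputs["adapter_gating_scores"])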

6 changes: 5 additions & 1 deletion src/adapters/heads/model_mixin.py
@@ -584,8 +584,12 @@ def _get_head_input(outputs, cls_out, batch):
kwargs["invertible_adapter"] = inv_adapter

# Set prompt tokens length
context = context or ForwardContext.get_context()
if context is not None:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if isinstance(context, ForwardContext):
prompt_tokens_length = getattr(context, "prompt_tokens_length", None)
else:
prompt_tokens_length = context.get("prompt_tokens_length", None)
if prompt_tokens_length is not None:
kwargs["prompt_tokens_length"] = prompt_tokens_length

11 changes: 6 additions & 5 deletions src/adapters/model_mixin.py
@@ -1138,8 +1138,7 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):

context.adapters_parallelized = False
# Check if already parallelized in encoder
adapter_input_parallelized = kwargs.pop("adapter_input_parallelized", None)
if adapter_input_parallelized:
if context.adapter_input_parallelized:
if active_adapters.parallel_channels > 1:
context.adapters_parallelized = True
# Add the shared parameters for the active adapters to the context
@@ -1165,8 +1164,6 @@ def forward_context(self, context: ForwardContext, *args, **kwargs):
context.offsets = attention_mask.argmax(1)

# Adapter gating and attention outputs
context.output_adapter_gating_scores = kwargs.get("output_adapter_gating_scores", False)
context.output_adapter_fusion_attentions = kwargs.get("output_adapter_fusion_attentions", False)
context.adapter_gating_scores = defaultdict(dict)
context.adapter_fusion_attentions = defaultdict(dict)

@@ -1704,7 +1701,7 @@ def post_embedding_forward(self, module, args, embedding_output):

return embedding_output

@ForwardContext.wrap
@ForwardContext.wrap_base
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)

@@ -2244,3 +2241,7 @@ def freeze_embeddings(self, freeze=True):
else:
for p in self.get_output_embeddings().parameters():
p.requires_grad = not freeze

@ForwardContext.wrap
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
14 changes: 3 additions & 11 deletions src/adapters/models/albert/adapter_model.py
@@ -6,7 +6,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -38,6 +38,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(ALBERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -50,8 +51,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -66,7 +65,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.albert(
outputs = self.albert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -76,14 +75,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

# BERT & RoBERTa & ALBERT return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
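The same simplification repeats in the BART, BEiT, BERT, and BertGeneration adapter models below: @ForwardContext.wrap replaces the hand-threaded output_adapter_* / output_context kwargs, and the inner base-model call returns a single outputs object again. The public interface of the flex-head classes is unchanged; a brief hedged sketch (checkpoint, adapter, and head names are assumptions):

from adapters import AutoAdapterModel
from transformers import AutoTokenizer

model = AutoAdapterModel.from_pretrained("albert-base-v2")
model.add_adapter("demo")
model.add_classification_head("demo", num_labels=2)
model.set_active_adapters("demo")

tokenizer = AutoTokenizer.from_pretrained("albert-base-v2")
batch = tokenizer("unchanged public API", return_tensors="pt")

outputs = model(**batch)     # context handling now happens inside the decorator
print(outputs.logits.shape)  # torch.Size([1, 2])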
12 changes: 3 additions & 9 deletions src/adapters/models/bart/adapter_model.py
@@ -11,6 +11,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -50,6 +51,7 @@ def get_decoder(self):
return self.model.get_decoder()

@add_start_docstrings_to_model_forward(BART_INPUTS_DOCSTRING)
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -68,8 +70,6 @@ def forward(
return_dict=None,
past_key_values=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
r"""
@@ -82,7 +82,7 @@ def forward(
if "labels" in kwargs or "start_positions" in kwargs and "end_positions" in kwargs:
use_cache = False

outputs, context = self.model(
outputs = self.model(
input_ids,
attention_mask=attention_mask,
decoder_input_ids=decoder_input_ids,
@@ -98,13 +98,7 @@ def forward(
output_hidden_states=output_hidden_states,
return_dict=return_dict,
past_key_values=past_key_values,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

head_outputs = self.forward_head(
outputs,
14 changes: 3 additions & 11 deletions src/adapters/models/beit/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...wrappers import init

@@ -51,6 +51,7 @@ def make_inputs_require_grads(module, input, output):
self._require_grads_hook = self.get_input_embeddings().register_forward_hook(make_inputs_require_grads)

@add_start_docstrings_to_model_forward(BEIT_INPUTS_DOCSTRING)
@ForwardContext.wrap
def forward(
self,
pixel_values: Optional[torch.Tensor] = None,
@@ -60,27 +61,18 @@ def forward(
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.beit(
outputs = self.beit(
pixel_values,
bool_masked_pos=bool_masked_pos,
head_mask=head_mask,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context

# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
13 changes: 3 additions & 10 deletions src/adapters/models/bert/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -43,6 +43,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(BERT_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -55,8 +56,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -71,7 +70,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.bert(
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
@@ -81,13 +80,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
13 changes: 3 additions & 10 deletions src/adapters/models/bert_generation/adapter_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
)
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward

from ...context import AdapterSetup
from ...context import AdapterSetup, ForwardContext
from ...heads import ModelWithFlexibleHeadsAdaptersMixin
from ...model_mixin import EmbeddingAdaptersWrapperMixin
from ...wrappers import init
@@ -38,6 +38,7 @@ def __init__(self, config):
self.init_weights()

@add_start_docstrings_to_model_forward(BERT_GENERATION_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
@ForwardContext.wrap
def forward(
self,
input_ids=None,
@@ -53,8 +54,6 @@ def forward(
output_hidden_states=None,
return_dict=None,
head=None,
output_adapter_gating_scores=False,
output_adapter_fusion_attentions=False,
**kwargs,
):
input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
@@ -68,7 +67,7 @@ def forward(

return_dict = return_dict if return_dict is not None else self.config.use_return_dict

outputs, context = self.bert(
outputs = self.bert(
input_ids,
attention_mask=attention_mask,
position_ids=position_ids,
@@ -81,13 +80,7 @@ def forward(
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=return_dict,
output_adapter_gating_scores=output_adapter_gating_scores,
output_adapter_fusion_attentions=output_adapter_fusion_attentions,
adapter_input_parallelized=kwargs.pop("adapter_input_parallelized", False),
output_context=True,
)
# required e.g. for prompt tuning in all models
kwargs["context"] = context
# BERT & RoBERTa return the pooled output as second item, we don't need that in these heads
if not return_dict:
head_inputs = (outputs[0],) + outputs[2:]
(Diffs for the remaining changed files are not shown here.)
