
Wrap ForwardContext around full model forward #789

Merged
merged 8 commits on Feb 26, 2025

Conversation

calpt
Member

@calpt calpt commented Feb 2, 2025

This PR adapts the ForwardContext so that it wraps the full model forward pass, including the prediction head. The original base model forward wrapper is moved to wrap_base to ensure no second ForwardContext is created for a single forward pass.

This enables passing custom args registered on the ForwardContext to the top-level model call, as discussed in #783, e.g.:

import adapters
from adapters import ForwardContext
from transformers import AutoModelForCausalLM, AutoTokenizer

# model_name: any causal LM checkpoint
model = AutoModelForCausalLM.from_pretrained(model_name)
adapters.init(model)
tokenizer = AutoTokenizer.from_pretrained(model_name)
inputs = tokenizer(["This is a test text"], return_tensors="pt")

# Registers a new forward arg globally
ForwardContext.context_args.add("task_ids")

# Now the new arg name can be used without modifying the model's forward method
output = model(**inputs, task_ids=["id_0", "id_1"])

In the example above, the forward context automatically adds the passed context args as attributes, i.e. they can be accessed within the forward pass like this:

task_ids = ForwardContext.get_context().task_ids
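To make the mechanism concrete, here is a minimal self-contained sketch of the pattern described above (this is an illustrative toy, not the actual adapters implementation): a class-level set of registered arg names, a context object that pops those args out of the forward kwargs and exposes them as attributes, and a wrapper that opens exactly one context around the full forward pass.

```python
class ForwardContext:
    """Toy sketch of the context-args pattern (not the adapters library code)."""

    context_args = {"task_ids"}  # globally registered extra forward kwargs
    _stack = []                  # active contexts, innermost last

    def __init__(self, **kwargs):
        # Pop every registered context arg out of the forward kwargs
        # and attach it as an attribute on the context object.
        for name in self.context_args:
            setattr(self, name, kwargs.pop(name, None))
        self.remaining_kwargs = kwargs

    def __enter__(self):
        ForwardContext._stack.append(self)
        return self

    def __exit__(self, *exc):
        ForwardContext._stack.pop()
        return False

    @classmethod
    def get_context(cls):
        return cls._stack[-1] if cls._stack else None


def model_forward(**kwargs):
    # The wrapper opens one ForwardContext around the full forward pass;
    # inner modules can then read the extra args via get_context().
    with ForwardContext(**kwargs):
        return ForwardContext.get_context().task_ids


print(model_forward(input_ids=[1, 2, 3], task_ids=["id_0", "id_1"]))
# → ['id_0', 'id_1']
```

Because the wrapper sits around the full model (head included), inner modules anywhere in the stack can read the same context without the arg being threaded through every intermediate forward signature.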

@FrLdy
Contributor

FrLdy commented Feb 2, 2025

Hello @calpt,
Thanks for the PR!
Just tested it out and it works great!
This definitely makes passing custom args much easier; there is no need to redefine a new forward for each specific HF model.
I'll rebase my fork once the changes are merged and propose my additions for multi-task support.

@calpt calpt marked this pull request as ready for review February 4, 2025 14:11
@calpt calpt linked an issue Feb 4, 2025 that may be closed by this pull request
@FrLdy
Contributor

FrLdy commented Feb 8, 2025

Hi,
I tried to include a new ForwardContext in my test for multi-task composition:
https://github.com/FrLdy/adapters/blob/6f63c4df35d8deed973f4c791f6b779f8ba4f668/tests/test_misc/test_adapter_composition.py#L177

However, if I don't check whether the new ForwardContext argument (task_ids) is already present in context_args, the tests run fine when each TestCase class is executed independently. But when they are run together, only the first one passes, while the subsequent ones fail.

Maybe use a set to store ForwardContext.context_args?

@FrLdy FrLdy mentioned this pull request Feb 8, 2025
@calpt
Member Author

calpt commented Feb 11, 2025

Maybe use a set to store ForwardContext.context_args ?

Makes sense, done.

@FrLdy FrLdy mentioned this pull request Feb 15, 2025
@FrLdy
Contributor

FrLdy commented Feb 17, 2025

Hello,
I tried to use the added context_args with the generate method.
Unfortunately, context arguments are rejected by the _validate_model_kwargs method of GenerationMixin.

import torch

import adapters
import transformers

model = transformers.T5ForConditionalGeneration(
    transformers.T5Config(num_layers=2, num_decoder_layers=2)
)
adapters.init(model)

model.add_adapter("a")

adapters.ForwardContext.context_args.add("task_ids")
model.generate(
    input_ids=torch.randint(0, 1000, (3, 128)),
    task_ids=torch.randint(0, 3, (3,)),
)

which raises

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[20], line 16
     13 model.add_adapter("a", overwrite_ok=True)
     15 adapters.ForwardContext.context_args.add("task_ids")
---> 16 model.generate(input_ids=torch.randint(0, 1000, (3,128)), task_ids=torch.randint(0, 3, (3,)))

File ~/Documents/labo/etr-peft-composition/.venv/lib/python3.10/site-packages/torch/utils/_contextlib.py:116, in context_decorator.<locals>.decorate_context(*args, **kwargs)
    113 @functools.wraps(func)
    114 def decorate_context(*args, **kwargs):
    115     with ctx_factory():
--> 116         return func(*args, **kwargs)

File ~/Documents/labo/etr-peft-composition/src/adapters/hf_transformers/src/transformers/generation/utils.py:2009, in GenerationMixin.generate(self, inputs, generation_config, logits_processor, stopping_criteria, prefix_allowed_tokens_fn, synced_gpus, assistant_model, streamer, negative_prompt_ids, negative_prompt_attention_mask, **kwargs)
   2006 assistant_tokenizer = kwargs.pop("assistant_tokenizer", None)  # only used for assisted generation
   2008 generation_config, model_kwargs = self._prepare_generation_config(generation_config, **kwargs)
-> 2009 self._validate_model_kwargs(model_kwargs.copy())
   2010 self._validate_assistant(assistant_model, tokenizer, assistant_tokenizer)
   2012 # 2. Set generation parameters if not already defined

File ~/Documents/labo/etr-peft-composition/src/adapters/hf_transformers/src/transformers/generation/utils.py:1388, in GenerationMixin._validate_model_kwargs(self, model_kwargs)
   1385         unused_model_args.append(key)
   1387 if unused_model_args:
-> 1388     raise ValueError(
   1389         f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
   1390         " generate arguments will also show up in this list)"
   1391     )

ValueError: The following `model_kwargs` are not used by the model: ['task_ids'] (note: typos in the generate arguments will also show up in this list)

Additional notes

  • No problem when using AutoAdapterModel.
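For context on the failure mode: _validate_model_kwargs compares the passed kwargs against what the model's forward accepts, so an arg that exists only in ForwardContext.context_args looks like a typo to it. A toy illustration of that check and of one possible fix, also accepting globally registered context args (this is a sketch of the idea, not the actual transformers code):

```python
import inspect

registered_context_args = {"task_ids"}  # stand-in for ForwardContext.context_args

def forward(input_ids=None, attention_mask=None):
    """Stand-in for a model forward with a fixed signature."""

def validate_model_kwargs(model_kwargs):
    # Accept anything in the forward signature...
    accepted = set(inspect.signature(forward).parameters)
    # ...and, as a possible fix, anything registered as a context arg.
    accepted |= registered_context_args
    unused = [k for k in model_kwargs if k not in accepted]
    if unused:
        raise ValueError(
            f"The following `model_kwargs` are not used by the model: {unused}"
        )

# With the whitelist extension, registered context args pass validation:
validate_model_kwargs({"attention_mask": None, "task_ids": [0, 1]})
```

Without the `accepted |= registered_context_args` line, the same call would raise exactly the ValueError shown in the traceback above.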

@calpt calpt mentioned this pull request Feb 21, 2025
@calpt calpt requested a review from hSterz February 21, 2025 15:25
Member

@hSterz hSterz left a comment


Looks good!

@calpt calpt merged commit acca075 into adapter-hub:main Feb 26, 2025
4 checks passed
@calpt calpt deleted the dev/full_context branch February 26, 2025 20:23
Development

Successfully merging this pull request may close these issues.

Support for Passing task_ids to forward_context for Multi-Task Learning
3 participants