Skip to content

Commit

Permalink
fix: generate encoder decoder model
Browse files: browse the repository at this point in the history
  • Loading branch information
Francois Ledoyen authored and committed on Feb 24, 2025
1 parent 833f58f commit 62cb855
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions src/adapters/model_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -1730,11 +1730,12 @@ def _prepare_encoder_decoder_kwargs_for_generation(
}
encoder_signature = set(inspect.signature(encoder.forward).parameters)
encoder_accepts_wildcard = "kwargs" in encoder_signature or "model_kwargs" in encoder_signature
forward_context_args = ["adapter_input_parallelized", "task_ids"]
if not encoder_accepts_wildcard:
encoder_kwargs = {
argument: value
for argument, value in encoder_kwargs.items()
if argument in encoder_signature or argument == "adapter_input_parallelized"
if argument in encoder_signature or argument in forward_context_args
}
encoder_kwargs["output_attentions"] = generation_config.output_attentions
encoder_kwargs["output_hidden_states"] = generation_config.output_hidden_states
Expand All @@ -1744,7 +1745,9 @@ def _prepare_encoder_decoder_kwargs_for_generation(
encoder_kwargs["return_dict"] = True
encoder_kwargs[model_input_name] = inputs_tensor
with ForwardContext(self, **encoder_kwargs):
encoder_kwargs.pop("adapter_input_parallelized", None) # This should not be passed to actual model
for arg_name in forward_context_args:
encoder_kwargs.pop(arg_name, None) # This should not be passed to actual model

model_kwargs["encoder_outputs"]: ModelOutput = encoder(**encoder_kwargs)

return model_kwargs
Expand Down

0 comments on commit 62cb855

Please sign in to comment.