Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for returning past_key_values from the model #1742

Merged
merged 6 commits into from
Jun 20, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions trl/models/modeling_value_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@ def forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
return_past_key_values=False,
**kwargs,
):
r"""
Expand All @@ -159,6 +160,7 @@ def forward(
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
- 1 for tokens that are **not masked**,
- 0 for tokens that are **masked**.
return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned.
kwargs (`dict`, `optional`):
Additional keyword arguments, that are passed to the wrapped model.
"""
Expand Down Expand Up @@ -187,7 +189,10 @@ def forward(
if lm_logits.dtype != torch.float32:
lm_logits = lm_logits.float()

return (lm_logits, loss, value)
if return_past_key_values:
return (lm_logits, loss, value, base_model_output.past_key_values)
else:
return (lm_logits, loss, value)

def generate(self, *args, **kwargs):
r"""
Expand Down Expand Up @@ -406,6 +411,7 @@ def forward(
input_ids=None,
past_key_values=None,
attention_mask=None,
return_past_key_values=False,
**kwargs,
):
kwargs["past_key_values"] = past_key_values
Expand All @@ -429,7 +435,10 @@ def forward(
if lm_logits.dtype != torch.float32:
lm_logits = lm_logits.float()

return (lm_logits, loss, value)
if return_past_key_values:
return (lm_logits, loss, value, base_model_output.past_key_values)
else:
return (lm_logits, loss, value)

def generate(self, *args, **kwargs):
r"""
Expand Down
Loading