Commit

change order of keys
idanshen authored Jun 17, 2024
1 parent d1b4efb commit d072b34
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions trl/models/modeling_value_head.py
@@ -143,8 +143,8 @@ def forward(
         self,
         input_ids=None,
         past_key_values=None,
-        return_past_key_values=False,
         attention_mask=None,
+        return_past_key_values=False,
         **kwargs,
     ):
r"""
@@ -156,11 +156,11 @@ def forward(
             past_key_values (`tuple(tuple(torch.FloatTensor))`, `optional`):
                 Contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
                 (see `past_key_values` input) to speed up sequential decoding.
-            return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned.
             attention_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, `optional`):
                 Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
                 - 1 for tokens that are **not masked**,
                 - 0 for tokens that are **masked**.
+            return_past_key_values (bool): A flag indicating if the computed hidden-states should be returned.
             kwargs (`dict`, `optional`):
                 Additional keyword arguments, that are passed to the wrapped model.
         """
@@ -410,8 +410,8 @@ def forward(
         self,
         input_ids=None,
         past_key_values=None,
-        return_past_key_values=False,
         attention_mask=None,
+        return_past_key_values=False,
         **kwargs,
     ):
         kwargs["past_key_values"] = past_key_values
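The commit message does not state the motivation, but moving `return_past_key_values` after `attention_mask` changes how positional arguments bind: a caller passing the mask positionally after `past_key_values` would previously have had it captured by `return_past_key_values`. A minimal sketch of that effect, using hypothetical stand-in functions rather than the actual TRL `forward` methods:

```python
# Hypothetical stand-ins illustrating keyword-parameter order and
# positional binding; these are NOT the real TRL forward signatures.

def forward_old(input_ids=None, past_key_values=None,
                return_past_key_values=False, attention_mask=None, **kwargs):
    # Old parameter order: return_past_key_values comes third.
    return {"return_past_key_values": return_past_key_values,
            "attention_mask": attention_mask}

def forward_new(input_ids=None, past_key_values=None,
                attention_mask=None, return_past_key_values=False, **kwargs):
    # New parameter order: attention_mask comes third.
    return {"return_past_key_values": return_past_key_values,
            "attention_mask": attention_mask}

mask = [1, 1, 1, 0]

# A caller passing the mask positionally as the third argument:
old = forward_old([1, 2, 3], None, mask)
new = forward_new([1, 2, 3], None, mask)

print(old)  # mask was silently bound to return_past_key_values
print(new)  # mask is bound to attention_mask, as intended
```

With the old order, `old["attention_mask"]` is `None` and the mask ends up in `return_past_key_values`; with the new order the same positional call binds the mask correctly. Keyword-only parameters (a bare `*` in the signature) are the usual way to rule this class of bug out entirely.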
