Skip to content

Commit

Permalink
Fix StoppingCriteriaSub parameters to be compatible with latest Trans…
Browse files Browse the repository at this point in the history
…formers (#215)
  • Loading branch information
kira-lin authored May 11, 2024
1 parent cc1556d commit a61d89f
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion llm_on_ray/inference/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def __init__(self, stops=[], encounters=1):
super().__init__()
self.stops = stops

def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
for stop in self.stops:
length = 1 if len(stop.size()) == 0 else stop.size()[0]
if torch.all((stop == input_ids[0][-length:])).item():
Expand Down

0 comments on commit a61d89f

Please sign in to comment.