diff --git a/llm_on_ray/inference/utils.py b/llm_on_ray/inference/utils.py
index 91e311088..2d1a4d878 100644
--- a/llm_on_ray/inference/utils.py
+++ b/llm_on_ray/inference/utils.py
@@ -95,7 +95,7 @@ def __init__(self, stops=[], encounters=1):
         super().__init__()
         self.stops = stops
 
-    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
+    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs):
         for stop in self.stops:
             length = 1 if len(stop.size()) == 0 else stop.size()[0]
             if torch.all((stop == input_ids[0][-length:])).item():