From ca2c0d4b939b7ee335ea64837d4546f58a35387e Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 30 Aug 2024 10:31:22 +0800 Subject: [PATCH 1/2] fix: add ngram repetition stop checker --- vllm/engine/output_processor/stop_checker.py | 78 ++++++++++++++++++++ vllm/sequence.py | 5 ++ 2 files changed, 83 insertions(+) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 4b701f81504bb..316752a2abb87 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -19,6 +19,16 @@ def __init__(self, max_model_len: int, self._max_model_len = max_model_len self.get_tokenizer_for_seq = get_tokenizer_for_seq + # the position to start checking for repetition + self.repeat_start_from = 0 + + # the number of tokens repeated + self.repeated_count = 0 + # the gap between the repeated tokens + self.repeated_gap = 0 + # the repeated ngram that we already generated + self.repeated_total = 0 + def _get_max_model_len(self, lora_req: Optional[LoRARequest]): if lora_req and lora_req.long_lora_max_len: return lora_req.long_lora_max_len @@ -88,6 +98,74 @@ def maybe_stop_sequence( seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED return + # Check if the last ngram is repeated in the output text. + last_token = seq.output_text[-new_char_count:] + # start checking for repetition after the first 32 tokens + if seq.get_output_len() > 32 and self.check_ngram_repetition(seq, sampling_params, last_token): + seq.status = SequenceStatus.FINISHED_REPEATED + return + + def check_ngram_repetition(self, seq: Sequence, sampling_params: SamplingParams, last_token: str) -> bool: + """Check if the last ngram is repeated in the output text. + """ + + is_done = False + output_ids = seq.get_output_token_ids() + last_token_id = seq.get_last_token_id() + output_len = seq.get_output_len() + + repeated_at = None + repeated_gap = None + + for i, token in enumerate(output_ids[self.repeat_start_from:-1]): + if token == last_token_id: + repeated_at = self.repeat_start_from + i + repeated_gap = output_len - repeated_at + + if repeated_at is not None: + self.repeated_count += 1 + # token_str = self.tokenizer.convert_ids_to_tokens([last_token])[0] + # print( + # f"\n==> token ({last_token}) at {output_len}\n" + # f"==> repeat_at: {repeated_at}\n" + # f"==> repeated_count: {self.repeated_count}\n" + # f"==> repeated_gap: {repeated_gap}\n" + # f"==> repeate_start_from: {self.repeat_start_from}" + # ) + + self.repeat_start_from = repeated_at + + if repeated_at is None or repeated_gap != self.repeated_gap: + self.repeated_count = 0 + self.repeated_gap = 0 + self.repeated_total = 0 + + if repeated_gap is not None: + self.repeated_gap = repeated_gap + + if self.repeated_count == self.repeated_gap and self.repeated_gap: + self.repeated_total += 1 + self.repeated_count = 0 + + # print(f"==> repeated_total: {self.repeated_total}") + + repeate_ngram_size = self.repeated_gap + # print(f'==> repeate_ngram_size: {repeate_ngram_size}') + + if repeate_ngram_size == 1: + # single token repetition + is_done = self.repeated_total > 64 + elif repeate_ngram_size > 64: + # paragraph repetition + is_done = self.repeated_total >= 4 + else: + # short ngram repetition? + is_done = self.repeated_total >= 8 + + return is_done + + + @staticmethod def check_stop_strings( output_text: str, diff --git a/vllm/sequence.py b/vllm/sequence.py index 5857f656dfc10..1e587b1b29d8a 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -65,6 +65,7 @@ class SequenceStatus(enum.IntEnum): FINISHED_LENGTH_CAPPED = 4 FINISHED_ABORTED = 5 FINISHED_IGNORED = 6 + FINISHED_REPEATED = 7 @staticmethod def is_finished(status: "SequenceStatus") -> bool: @@ -83,6 +84,10 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]: # are longer than the model's length cap. Therefore, the stop # reason should also be "length" as in OpenAI API. finish_reason = "length" + elif status == SequenceStatus.FINISHED_REPEATED: + # The repeated sequences are the generated sequences appeared + # ngram repeated. + finish_reason = "ngram_repeat" else: finish_reason = None return finish_reason From f0c20e16e40dc8c019a3c4988de69fb000e8d557 Mon Sep 17 00:00:00 2001 From: numb3r3 Date: Fri, 30 Aug 2024 10:42:51 +0800 Subject: [PATCH 2/2] fix: format codes --- vllm/engine/output_processor/stop_checker.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/vllm/engine/output_processor/stop_checker.py b/vllm/engine/output_processor/stop_checker.py index 316752a2abb87..a43c79b4ea8c9 100644 --- a/vllm/engine/output_processor/stop_checker.py +++ b/vllm/engine/output_processor/stop_checker.py @@ -101,11 +101,14 @@ def maybe_stop_sequence( # Check if the last ngram is repeated in the output text. last_token = seq.output_text[-new_char_count:] # start checking for repetition after the first 32 tokens - if seq.get_output_len() > 32 and self.check_ngram_repetition(seq, sampling_params, last_token): + if seq.get_output_len() > 32 and self.check_ngram_repetition( + seq, sampling_params, last_token): seq.status = SequenceStatus.FINISHED_REPEATED return - def check_ngram_repetition(self, seq: Sequence, sampling_params: SamplingParams, last_token: str) -> bool: + def check_ngram_repetition(self, seq: Sequence, + sampling_params: SamplingParams, + last_token: str) -> bool: """Check if the last ngram is repeated in the output text. """ @@ -164,8 +167,6 @@ def check_ngram_repetition(self, seq: Sequence, sampling_params: SamplingParams, return is_done - - @staticmethod def check_stop_strings( output_text: str,