fix: add ngram repetition stop checker
numb3r3 committed Jan 11, 2025
1 parent 7a3a83e commit ca2c0d4
Showing 2 changed files with 83 additions and 0 deletions.
78 changes: 78 additions & 0 deletions vllm/engine/output_processor/stop_checker.py
@@ -19,6 +19,16 @@ def __init__(self, max_model_len: int,
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq

# The output position from which to resume scanning for repetition.
self.repeat_start_from = 0

# Number of consecutive new tokens that continued the current
# repetition pattern.
self.repeated_count = 0
# Spacing between the current position and the previous occurrence
# of the repeated token.
self.repeated_gap = 0
# Number of complete ngram repetitions observed so far.
self.repeated_total = 0

def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len:
return lora_req.long_lora_max_len
@@ -88,6 +98,74 @@ def maybe_stop_sequence(
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check whether the output has fallen into ngram repetition.
# `last_token` holds the text appended in this decoding step.
last_token = seq.output_text[-new_char_count:]
# Only start checking for repetition after the first 32 tokens.
if seq.get_output_len() > 32 and self.check_ngram_repetition(seq, sampling_params, last_token):
seq.status = SequenceStatus.FINISHED_REPEATED
return

def check_ngram_repetition(self, seq: Sequence, sampling_params: SamplingParams, last_token: str) -> bool:
"""Check whether the most recent token extends an ngram that is
repeating in the output."""

is_done = False
output_ids = seq.get_output_token_ids()
last_token_id = seq.get_last_token_id()
output_len = seq.get_output_len()

repeated_at = None
repeated_gap = None

# Find the most recent earlier occurrence of the newest token;
# scanning the whole window keeps the last (closest) match.
for i, token in enumerate(output_ids[self.repeat_start_from:-1]):
if token == last_token_id:
repeated_at = self.repeat_start_from + i
repeated_gap = output_len - repeated_at

if repeated_at is not None:
self.repeated_count += 1
self.repeat_start_from = repeated_at

# The pattern broke: either no earlier occurrence was found or the
# spacing changed, so reset all repetition state.
if repeated_at is None or repeated_gap != self.repeated_gap:
self.repeated_count = 0
self.repeated_gap = 0
self.repeated_total = 0

if repeated_gap is not None:
self.repeated_gap = repeated_gap

# Once repeated_gap consecutive tokens have continued the pattern,
# one full ngram repetition is complete.
if self.repeated_count == self.repeated_gap and self.repeated_gap:
self.repeated_total += 1
self.repeated_count = 0

repeat_ngram_size = self.repeated_gap

if repeat_ngram_size == 1:
# Single-token repetition: tolerate many repeats before stopping.
is_done = self.repeated_total > 64
elif repeat_ngram_size > 64:
# Paragraph-scale repetition: stop after a few full repeats.
is_done = self.repeated_total >= 4
else:
# Short ngram repetition.
is_done = self.repeated_total >= 8

return is_done

@staticmethod
def check_stop_strings(
output_text: str,
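To make the counter logic above concrete, here is a minimal standalone sketch (hypothetical code, not part of the commit) that applies the same scan-and-count scheme to a toy token stream; detect_repetition and its single max_repeats threshold are illustrative simplifications of the three-branch thresholds in the diff:

def detect_repetition(token_ids, max_repeats=8):
    """Return the position where repetition would stop generation, else None."""
    repeat_start_from = 0
    repeated_count = 0
    repeated_gap = 0
    repeated_total = 0

    for pos in range(1, len(token_ids)):
        last_token_id = token_ids[pos]
        output_len = pos + 1

        repeated_at = None
        gap = None
        # Keep the most recent earlier occurrence of the newest token.
        for i, token in enumerate(token_ids[repeat_start_from:pos]):
            if token == last_token_id:
                repeated_at = repeat_start_from + i
                gap = output_len - repeated_at

        if repeated_at is not None:
            repeated_count += 1
            repeat_start_from = repeated_at

        # Pattern broke (no match, or the spacing changed): reset.
        if repeated_at is None or gap != repeated_gap:
            repeated_count = 0
            repeated_gap = 0
            repeated_total = 0

        if gap is not None:
            repeated_gap = gap

        # A full ngram has repeated once every `repeated_gap` matches.
        if repeated_count == repeated_gap and repeated_gap:
            repeated_total += 1
            repeated_count = 0

        if repeated_total >= max_repeats:
            return pos
    return None

# A stream that loops on the bigram (5, 7) trips the checker:
stream = [1, 2, 3] + [5, 7] * 20
print(detect_repetition(stream))  # prints a position, not None
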
5 changes: 5 additions & 0 deletions vllm/sequence.py
@@ -65,6 +65,7 @@ class SequenceStatus(enum.IntEnum):
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
FINISHED_REPEATED = 7

@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
@@ -83,6 +84,10 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
finish_reason = "length"
elif status == SequenceStatus.FINISHED_REPEATED:
# The sequence finished because its generated text kept
# repeating the same ngram.
finish_reason = "ngram_repeat"
else:
finish_reason = None
return finish_reason
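Downstream, the new finish reason surfaces like any other. A minimal usage sketch, assuming the standard vLLM offline generation API (the model name and prompt are illustrative only):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # illustrative model choice
params = SamplingParams(temperature=0.0, max_tokens=512)

outputs = llm.generate(["Repeat after me: ha"], params)
completion = outputs[0].outputs[0]
# With this commit, a degenerate loop ends with the new reason.
if completion.finish_reason == "ngram_repeat":
    print("Stopped by the ngram repetition checker.")
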
