Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feat repeat stoper #3

Merged
merged 2 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 79 additions & 0 deletions vllm/engine/output_processor/stop_checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,16 @@ def __init__(self, max_model_len: int,
self._max_model_len = max_model_len
self.get_tokenizer_for_seq = get_tokenizer_for_seq

# the position to start checking for repetition
self.repeat_start_from = 0

# the number of tokens repeated
self.repeated_count = 0
# the gap between the repeated tokens
self.repeated_gap = 0
# the repeated ngram that we already generated
self.repeated_total = 0

def _get_max_model_len(self, lora_req: Optional[LoRARequest]):
if lora_req and lora_req.long_lora_max_len:
return lora_req.long_lora_max_len
Expand Down Expand Up @@ -88,6 +98,75 @@ def maybe_stop_sequence(
seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
return

# Check if the last ngram is repeated in the output text.
last_token = seq.output_text[-new_char_count:]
# start checking for repetition after the first 32 tokens
if seq.get_output_len() > 32 and self.check_ngram_repetition(
seq, sampling_params, last_token):
seq.status = SequenceStatus.FINISHED_REPEATED
return

def check_ngram_repetition(self, seq: Sequence,
sampling_params: SamplingParams,
last_token: str) -> bool:
"""Check if the last ngram is repeated in the output text.
"""

is_done = False
output_ids = seq.get_output_token_ids()
last_token_id = seq.get_last_token_id()
output_len = seq.get_output_len()

repeated_at = None
repeated_gap = None

for i, token in enumerate(output_ids[self.repeat_start_from:-1]):
if token == last_token_id:
repeated_at = self.repeat_start_from + i
repeated_gap = output_len - repeated_at

if repeated_at is not None:
self.repeated_count += 1
# token_str = self.tokenizer.convert_ids_to_tokens([last_token])[0]
# print(
# f"\n==> token ({last_token}) at {output_len}\n"
# f"==> repeat_at: {repeated_at}\n"
# f"==> repeated_count: {self.repeated_count}\n"
# f"==> repeated_gap: {repeated_gap}\n"
# f"==> repeate_start_from: {self.repeat_start_from}"
# )

self.repeat_start_from = repeated_at

if repeated_at is None or repeated_gap != self.repeated_gap:
self.repeated_count = 0
self.repeated_gap = 0
self.repeated_total = 0

if repeated_gap is not None:
self.repeated_gap = repeated_gap

if self.repeated_count == self.repeated_gap and self.repeated_gap:
self.repeated_total += 1
self.repeated_count = 0

# print(f"==> repeated_total: {self.repeated_total}")

repeate_ngram_size = self.repeated_gap
# print(f'==> repeate_ngram_size: {repeate_ngram_size}')

if repeate_ngram_size == 1:
# single token repetition
is_done = self.repeated_total > 64
elif repeate_ngram_size > 64:
# paragraph repetition
is_done = self.repeated_total >= 4
else:
# short ngram repetition?
is_done = self.repeated_total >= 8

return is_done

@staticmethod
def check_stop_strings(
output_text: str,
Expand Down
5 changes: 5 additions & 0 deletions vllm/sequence.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ class SequenceStatus(enum.IntEnum):
FINISHED_LENGTH_CAPPED = 4
FINISHED_ABORTED = 5
FINISHED_IGNORED = 6
FINISHED_REPEATED = 7

@staticmethod
def is_finished(status: "SequenceStatus") -> bool:
Expand All @@ -83,6 +84,10 @@ def get_finished_reason(status: "SequenceStatus") -> Union[str, None]:
# are longer than the model's length cap. Therefore, the stop
# reason should also be "length" as in OpenAI API.
finish_reason = "length"
elif status == SequenceStatus.FINISHED_REPEATED:
# The repeated sequences are the generated sequences appeared
# ngram repeated.
finish_reason = "ngram_repeat"
else:
finish_reason = None
return finish_reason
Expand Down
Loading