[V1][Spec Decode] Ngram Spec Decode (#12193)
Signed-off-by: LiuXiaoxuanPKU <[email protected]>
LiuXiaoxuanPKU authored Feb 16, 2025
1 parent 367cb8c commit 80f63a3
Showing 21 changed files with 1,025 additions and 84 deletions.
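
For context on what this commit adds: ngram ("prompt lookup") speculative decoding needs no separate draft model. It proposes draft tokens by matching the last few tokens of the sequence against earlier text and replaying what followed the match, with the match length bounded by ngram_prompt_lookup_min/max and the draft length by num_speculative_tokens. A minimal sketch of that idea (hypothetical helper name, not the code added in this commit):

from typing import List, Optional


def propose_ngram_draft(token_ids: List[int], ngram_min: int, ngram_max: int,
                        num_speculative_tokens: int) -> Optional[List[int]]:
    """Toy prompt-lookup proposer: return draft token ids, or None if no match."""
    for n in range(ngram_max, ngram_min - 1, -1):  # prefer longer n-grams
        if len(token_ids) < n:
            continue
        suffix = token_ids[-n:]
        # Scan earlier positions for the same n-gram, most recent first.
        for start in range(len(token_ids) - n - 1, -1, -1):
            if token_ids[start:start + n] == suffix:
                draft = token_ids[start + n:start + n + num_speculative_tokens]
                if draft:
                    return draft
    return None


# The tail [3, 4] last appeared at index 1, followed by [5, 6].
assert propose_ngram_draft([1, 3, 4, 5, 6, 3, 4], 2, 3, 2) == [5, 6]
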
202 changes: 196 additions & 6 deletions tests/v1/core/test_scheduler.py
@@ -4,10 +4,12 @@
from vllm.config import CacheConfig, ModelConfig, SchedulerConfig
from vllm.multimodal.inputs import MultiModalKwargs, PlaceholderRange
from vllm.sampling_params import SamplingParams
from vllm.v1.core.scheduler import Scheduler
from vllm.v1.core.scheduler import Scheduler, SchedulerOutput
from vllm.v1.outputs import ModelRunnerOutput
from vllm.v1.request import Request, RequestStatus

EOS_TOKEN_ID = 50256


def create_scheduler(
model: str = "facebook/opt-125m",
Expand Down Expand Up @@ -38,6 +40,7 @@ def create_scheduler(
return Scheduler(scheduler_config,
model_config,
cache_config,
speculative_config=None,
lora_config=None,
log_stats=True)

@@ -46,8 +49,12 @@ def create_requests(
num_requests: int,
num_tokens: int = 10,
mm_positions: Optional[List[PlaceholderRange]] = None,
max_tokens: int = 16,
stop_token_ids: Optional[List[int]] = None,
):
sampling_params = SamplingParams()
sampling_params = SamplingParams(ignore_eos=False,
max_tokens=max_tokens,
stop_token_ids=stop_token_ids)
requests = []
for i in range(num_requests):
if mm_positions is not None:
@@ -64,7 +71,7 @@ def create_requests(
multi_modal_inputs=mm_inputs,
multi_modal_placeholders=mm_position,
multi_modal_hashes=None,
eos_token_id=None,
eos_token_id=EOS_TOKEN_ID,
arrival_time=0,
)
requests.append(request)
@@ -195,7 +202,7 @@ def test_schedule_partial_requests():
model_runner_output = ModelRunnerOutput(
req_ids=[request.request_id for request in requests],
req_id_to_index=req_to_index,
sampled_token_ids=[0] * len(requests),
sampled_token_ids=[[0] for _ in range(len(requests))],
logprobs=None,
prompt_logprobs_dict={},
)
@@ -215,6 +222,189 @@ def test_schedule_partial_requests():
assert requests[2].request_id not in output.num_scheduled_tokens


def test_stop_via_update_from_output():
"""Test stopping behavior through update_from_output"""
scheduler = create_scheduler()

# Test case 1: Stop on EOS token
requests = create_requests(num_requests=2, max_tokens=10)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 1,
requests[1].request_id: 2
},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [],
requests[1].request_id: [10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])

model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[EOS_TOKEN_ID], [10, 11]],  # First request hits EOS, second continues
logprobs=None,
prompt_logprobs_dict={})

scheduler.update_from_output(scheduler_output, model_output)

# Verify first request stopped, second continues
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID]
assert list(requests[1].output_token_ids) == [10, 11]

# Test case 2: Stop on custom stop token
scheduler = create_scheduler()
requests = create_requests(num_requests=2,
max_tokens=10,
stop_token_ids=[42, 43])
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 2
},
total_num_scheduled_tokens=5,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 42],
requests[1].request_id: [13]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])

model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 42, 12],
[13, 14]], # First request hits stop token
logprobs=None,
prompt_logprobs_dict={})

scheduler.update_from_output(scheduler_output, model_output)

# Verify first request stopped on custom token
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_STOPPED
assert requests[0].stop_reason == 42
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 42]
assert list(requests[1].output_token_ids) == [13, 14]

# Test case 3: Stop on max tokens
scheduler = create_scheduler()
requests = create_requests(num_requests=2, max_tokens=2)
for req in requests:
req.num_computed_tokens = req.num_tokens
scheduler.requests[req.request_id] = req
scheduler.running.append(req)
scheduler.scheduled_req_ids.add(req.request_id)

scheduler_output = SchedulerOutput(scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={
requests[0].request_id: 3,
requests[1].request_id: 1
},
total_num_scheduled_tokens=4,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [10, 11],
requests[1].request_id: []
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])

model_output = ModelRunnerOutput(
req_ids=[req.request_id for req in requests],
req_id_to_index={
req.request_id: i
for i, req in enumerate(requests)
},
sampled_token_ids=[[10, 11, 12],
[13]], # First request exceeds max_tokens
logprobs=None,
prompt_logprobs_dict={})

scheduler.update_from_output(scheduler_output, model_output)

# Verify first request stopped due to length
assert len(scheduler.running) == 1
assert scheduler.running[0].request_id == requests[1].request_id
assert requests[0].status == RequestStatus.FINISHED_LENGTH_CAPPED
assert requests[0].request_id in scheduler.finished_req_ids
assert list(requests[0].output_token_ids) == [10, 11]  # Truncated to max_tokens
assert list(requests[1].output_token_ids) == [13]

# Test case 4: Ignore EOS flag
scheduler = create_scheduler()
requests = create_requests(num_requests=1, max_tokens=10)
requests[0].sampling_params.ignore_eos = True
requests[0].num_computed_tokens = requests[0].num_tokens
scheduler.requests[requests[0].request_id] = requests[0]
scheduler.running.append(requests[0])
scheduler.scheduled_req_ids.add(requests[0].request_id)

scheduler_output = SchedulerOutput(
scheduled_new_reqs=[],
scheduled_cached_reqs=[],
num_scheduled_tokens={requests[0].request_id: 3},
total_num_scheduled_tokens=3,
scheduled_encoder_inputs={},
scheduled_spec_decode_tokens={
requests[0].request_id: [EOS_TOKEN_ID, 10]
},
num_common_prefix_blocks=0,
finished_req_ids=set(),
free_encoder_input_ids=[])

model_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[[EOS_TOKEN_ID, 10, 11]],
logprobs=None,
prompt_logprobs_dict={})

scheduler.update_from_output(scheduler_output, model_output)

# Verify request continues past EOS
assert len(scheduler.running) == 1
assert not requests[0].is_finished()
assert list(requests[0].output_token_ids) == [EOS_TOKEN_ID, 10, 11]


def test_schedule_concurrent_batches():
scheduler = create_scheduler(
max_num_batched_tokens=1024,
@@ -243,7 +433,7 @@ def test_schedule_concurrent_batches():
model_runner_output = ModelRunnerOutput(
req_ids=[requests[0].request_id],
req_id_to_index={requests[0].request_id: 0},
sampled_token_ids=[0],
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
)
@@ -259,7 +449,7 @@ def test_schedule_concurrent_batches():
model_runner_output = ModelRunnerOutput(
req_ids=[requests[1].request_id],
req_id_to_index={requests[1].request_id: 0},
sampled_token_ids=[0],
sampled_token_ids=[[0]],
logprobs=None,
prompt_logprobs_dict={},
)
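
Note on the sampled_token_ids change above: it goes from one token id per request to a list per request, because with speculative decoding a request can produce several tokens in a single step (accepted draft tokens plus a bonus token). The new stop tests then check that EOS, custom stop tokens, and max_tokens cut the output at the first stopping token instead of only inspecting the final token. A rough sketch of that truncation rule (assumed helper, not vLLM's scheduler code):

from typing import List, Optional, Set, Tuple


def accept_until_stop(sampled: List[int], eos_token_id: int,
                      stop_token_ids: Set[int], max_tokens: int,
                      num_output_tokens: int,
                      ignore_eos: bool = False) -> Tuple[List[int], Optional[str]]:
    """Keep sampled tokens up to and including the first one that stops the request."""
    kept: List[int] = []
    for tok in sampled:
        kept.append(tok)
        num_output_tokens += 1
        if not ignore_eos and tok == eos_token_id:
            return kept, "eos"
        if tok in stop_token_ids:
            return kept, "stop"
        if num_output_tokens >= max_tokens:
            return kept, "length"
    return kept, None


# Mirrors test case 2: stop token 42 truncates the output to [10, 42].
assert accept_until_stop([10, 42, 12], 50256, {42, 43}, 10, 0) == ([10, 42], "stop")
# Mirrors test case 3: max_tokens=2 caps the output at [10, 11].
assert accept_until_stop([10, 11, 12], 50256, set(), 2, 0) == ([10, 11], "length")
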
49 changes: 49 additions & 0 deletions tests/v1/e2e/test_ngram_spec_decode.py
@@ -0,0 +1,49 @@
# SPDX-License-Identifier: Apache-2.0
import pytest

from vllm import LLM, SamplingParams


@pytest.fixture
def test_prompts():
return [
"Can you repeat the sentence ten times, this is a sentence.",
"Can you repeat the sentence ten times, this is a test.",
]


@pytest.fixture
def sampling_config():
# Only greedy sampling is supported for now
return SamplingParams(temperature=0, max_tokens=30, ignore_eos=False)


@pytest.fixture
def model_name():
return "meta-llama/Meta-Llama-3-8B-Instruct"


def test_ngram_correctness(monkeypatch, test_prompts, sampling_config,
model_name):
'''
Compare the outputs of the original LLM and the speculative LLM.
They should be the same when using ngram speculative decoding.
'''
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")

ref_llm = LLM(model=model_name)
ref_outputs = ref_llm.generate(test_prompts, sampling_config)
del ref_llm

spec_llm = LLM(model=model_name,
speculative_model='[ngram]',
ngram_prompt_lookup_max=5,
ngram_prompt_lookup_min=3,
num_speculative_tokens=3)
spec_outputs = spec_llm.generate(test_prompts, sampling_config)
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
assert ref_output.outputs[0].text == spec_output.outputs[0].text, \
(f"ref_output: {ref_output.outputs[0].text},"
f"spec_output: {spec_output.outputs[0].text}")
del spec_llm
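
A short note on the prompt choice: these prompts push the model to repeat itself, which is the favorable case for prompt lookup, since the trailing n-gram keeps reappearing earlier in the sequence and the replayed continuation is usually accepted. A toy illustration with whitespace words standing in for tokens (not the real tokenizer or proposer):

text = ("Can you repeat the sentence ten times, this is a sentence. "
        "This is a sentence. This is a")
words = text.lower().replace(".", "").replace(",", "").split()

tail = words[-3:]  # last 3-gram: ['this', 'is', 'a']
for i in range(len(words) - 4, -1, -1):  # most recent earlier occurrence
    if words[i:i + 3] == tail:
        print("draft:", words[i + 3:i + 3 + 3])  # ['sentence', 'this', 'is']
        break
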