Extend sampler tests #200

Merged · Feb 22, 2024
Commits (34)
c21e2ea
Add test cases for temperature
Ailurus1 Feb 8, 2024
7e0a64f
Fix temperature parameter verification
Ailurus1 Feb 8, 2024
e755eb7
Extend tests on temperature
Ailurus1 Feb 9, 2024
0bd3c4c
Fix import
Ailurus1 Feb 9, 2024
20c7da5
Fix test for logprobs after rebase
Ailurus1 Feb 9, 2024
1aa769f
Add tests for repetition penalty
Ailurus1 Feb 9, 2024
5b5e256
Refactor test for penalties
Ailurus1 Feb 9, 2024
12ded7d
Fix types + update tests for logprobs after rebase
Ailurus1 Feb 13, 2024
f4681c1
Fix type of error in tests
Ailurus1 Feb 13, 2024
38fb95d
Fix tests for temperature
Ailurus1 Feb 14, 2024
f26710e
Update test for penalties
Ailurus1 Feb 14, 2024
3340a89
Extend tests for top_p, top_k
Ailurus1 Feb 14, 2024
0accd5e
Remove irrelevant todo
Ailurus1 Feb 14, 2024
1559d18
Extend test for logit_bias
Ailurus1 Feb 14, 2024
5c2c3a5
Remove underlines to execute with pytest
Ailurus1 Feb 14, 2024
869511e
Add test for num_sequences
Ailurus1 Feb 14, 2024
792f51d
format + lint
Ailurus1 Feb 14, 2024
a05f31f
Upstream new changes
Ailurus1 Feb 15, 2024
c0d3bd2
Add test for inspecting behaviour of logprobs depending on temperature
Ailurus1 Feb 15, 2024
b7b21be
Merge new changes
Ailurus1 Feb 16, 2024
4ef3dcb
Test with mixed greedy and random sampling requests + fix after merge
Ailurus1 Feb 16, 2024
cdeb1c5
Fix get_sampling_state
Ailurus1 Feb 16, 2024
b72a2a9
Redesign test penalties
Ailurus1 Feb 19, 2024
29184f7
Merge changes from batch-serving
Ailurus1 Feb 19, 2024
20d28e0
Set top_k as vocab_size when -1 + simplify test for penalties
Ailurus1 Feb 19, 2024
2b28e33
Add pytest parametrization
Ailurus1 Feb 19, 2024
48300f0
Update mixed test
Ailurus1 Feb 19, 2024
0409978
Skip broken test
Ailurus1 Feb 19, 2024
89b6999
Remove debug print
Ailurus1 Feb 19, 2024
9c53ca4
Corrections according to review comments + add simple test for logit_…
Ailurus1 Feb 21, 2024
3b3548f
Fix indices in logit_bias since it starts from 0
Ailurus1 Feb 21, 2024
7949db5
Update test for penalties
Ailurus1 Feb 21, 2024
bf0c848
Add greedy sampling case in test for penalties
Ailurus1 Feb 21, 2024
a4bda30
Remove debug lines
Ailurus1 Feb 21, 2024
12 changes: 9 additions & 3 deletions serve/mlc_serve/engine/sampling_params.py
@@ -88,6 +88,8 @@ def __post_init__(self):
             self._verify_greedy_sampling()
         if not self.logprobs:
             self.top_logprobs = 0
+        if self.top_k == -1:
+            self.top_k = self.vocab_size

     def verify(self) -> None:
         if not -2.0 <= self.presence_penalty <= 2.0:
@@ -99,15 +101,15 @@ def verify(self) -> None:
                 "frequency_penalty must be in [-2, 2], got "
                 f"{self.frequency_penalty}."
             )
-        if self.temperature < 0.0:
+        if not 0.0 <= self.temperature <= 2.0:
             raise ValueError(
-                f"temperature must be non-negative, got {self.temperature}."
+                f"temperature must be in [0, 2], got {self.temperature}."
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")

         if not isinstance(self.top_k, int):
-            raise ValueError(f"top_k must be integer.")
+            raise TypeError(f"top_k must be integer.")

         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
@@ -132,6 +134,10 @@ def verify(self) -> None:
                 raise ValueError(
                     f"top_logprobs must be between 0 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}."
                 )
+            if not isinstance(self.top_logprobs, int):
+                raise TypeError(
+                    "top_logprobs must be integer"
+                )

     def _verify_greedy_sampling(self) -> None:
         if self.top_p < 1.0 - _SAMPLING_EPS:
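For context, a minimal sketch of how the tightened validation is expected to behave from a caller's perspective. This is an illustration, not code from the PR; it assumes SamplingParams can be constructed with only the keyword arguments shown (i.e. the remaining fields have workable defaults) and that verify() is invoked explicitly.

import pytest
from mlc_serve.engine.sampling_params import SamplingParams

def test_verify_rejects_out_of_range_params():
    # temperature is now capped at 2.0 instead of merely being non-negative
    with pytest.raises(ValueError):
        SamplingParams(temperature=2.5).verify()
    # a non-integer top_k now raises TypeError rather than ValueError
    with pytest.raises(TypeError):
        SamplingParams(top_k=2.5).verify()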
4 changes: 2 additions & 2 deletions serve/mlc_serve/model/sampler.py
@@ -540,8 +540,8 @@ def _is_safe_to_sample(prob_like):
             assert sampling_state.sampling_params[batch_idx].logprobs
             top_k = sampling_state.sampling_params[batch_idx].top_logprobs
             logprob_infos[batch_idx] = RawLogprobsInfo(
-                current_token_id=next_token,
-                current_logprob=logprobs[batch_idx][next_token],
+                current_token_id=int(next_token),
+                current_logprob=float(logprobs[batch_idx][next_token]),
                 top_token_ids=top_tokens[idx][:top_k],
                 top_logprobs=top_logprobs[idx][:top_k],
             )
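A note on the casts above: next_token and the indexed logprob presumably come back as framework scalar types (e.g. NumPy or torch scalars) rather than built-in Python numbers, and the explicit int()/float() conversions normalize them before they are stored in RawLogprobsInfo. A small illustration of the distinction, using NumPy as an assumed stand-in for whatever array type the sampler actually returns:

import numpy as np

next_token = np.array([0.1, 0.7, 0.2]).argmax()   # a NumPy integer scalar, not a built-in int
token_id = int(next_token)                        # the cast yields a plain Python int
assert isinstance(token_id, int)

logprob = np.log(np.array([0.1, 0.7, 0.2]))[1]    # indexing a NumPy array yields a NumPy float scalar
assert isinstance(float(logprob), float)          # float() converts it to a built-in float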
87 changes: 86 additions & 1 deletion serve/tests/unittest/test_engine_with_samplers.py
@@ -11,7 +11,7 @@
 from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args, create_mlc_engine
 import random
 from pydantic import BaseModel
-from typing import List
+from typing import List, Callable


 def create_request(
@@ -22,6 +22,7 @@ def create_request(
     pre_pen,
     max_tokens,
     stop,
+    num_sequences=1,
     ignore_eos=False,
     top_logprobs=0,
     logprobs=False,
@@ -41,6 +42,7 @@ def create_request(
             json_schema=json_schema,
         ),
         stopping_criteria=StoppingCriteria(max_tokens=max_tokens, stop_sequences=stop),
+        num_sequences=num_sequences,
         debug_options=DebugOptions(ignore_eos=ignore_eos),
     )

@@ -257,6 +259,46 @@ def _test_logprobs(
                 )
             generated[int(res.request_id)] += seq.delta

+    # If temperature is increasing then difference between
+    # boundaries of range of top logprobs in response must decrease
+    temperatures = [0.2, 1.1, 2.0]
+    mean_bounds_diff = [0 for _ in range(num_requests * len(temperatures))]
+    for idx, temp in enumerate(temperatures):
+        requests = [
+            create_request(
+                idx=str(n),
+                prompt=random.choice(prompts),
+                temp=temp,
+                freq_pen=0,
+                pre_pen=0,
+                max_tokens=300,
+                stop=None,
+                ignore_eos=True,
+                logprobs=True,
+                top_logprobs=5
+            )
+            for n in range(num_requests)
+        ]
+        engine.add(requests)
+
+        while engine.has_pending_requests():
+            results = engine.step()
+            for res in results.outputs:
+                seq = res.sequences[0]
+                req = requests[int(res.request_id)]
+
+                if not seq.is_finished:
+                    mean_bounds_diff[idx * num_requests + int(res.request_id)] += \
+                        seq.logprob_info[0].top_logprobs[0].logprob \
+                        - seq.logprob_info[0].top_logprobs[4].logprob
+                else:
+                    mean_bounds_diff[idx * num_requests + int(res.request_id)] /= seq.num_generated_tokens
+
+    for num_req_batch in range(num_requests):
+        for idx in range(1, len(temperatures)):
+            assert mean_bounds_diff[idx * num_requests + num_req_batch] < \
+                mean_bounds_diff[(idx - 1) * num_requests + num_req_batch]
+

 def _test_logprobs_mixed_requests(
     engine,
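The assertion block above leans on a standard property of temperature-scaled softmax: with logits z and temperature T, logprob_i = z_i / T - logsumexp(z / T), so the gap between the top-1 and top-5 logprobs equals (z_1 - z_5) / T and shrinks as T grows. A self-contained sketch of that property (an illustration only, not part of the diff):

import numpy as np

def top_logprob_gap(logits: np.ndarray, temperature: float) -> float:
    # log-softmax with temperature; the normalizer cancels out in the gap below
    scaled = logits / temperature
    logprobs = scaled - np.logaddexp.reduce(scaled)
    top = np.sort(logprobs)[::-1]
    return top[0] - top[4]

logits = np.array([5.0, 3.5, 2.0, 1.0, 0.5, -1.0])
# the same ordering the test checks: hotter sampling means a tighter top-logprob spread
assert top_logprob_gap(logits, 2.0) < top_logprob_gap(logits, 1.1) < top_logprob_gap(logits, 0.2)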
@@ -301,6 +343,48 @@ def _test_logprobs_mixed_requests(
                 assert len(seq.logprob_info) == 0
             generated[int(res.request_id)] += seq.delta

+def _test_num_sequences(
+    engine,
+    num_requests=5,
+):
+    prompt = "Write a merge sort program in Python."
+    requests = []
+    num_sequences = [2 * i for i in range(1, num_requests + 1)]
+    for n, num_seq in enumerate(num_sequences):
+        requests.append(
+            create_request(
+                idx=str(n),
+                prompt=prompt,
+                temp=0.6,
+                freq_pen=0,
+                pre_pen=0,
+                stop=None,
+                max_tokens=300,
+                ignore_eos=False,
+                num_sequences=num_seq
+            )
+        )
+    engine.add(requests)
+
+    generated = [[""] * num_seq for _, num_seq in zip(range(num_requests), num_sequences)]
+    unique_sequences = [set() for _ in range(num_requests)]
+    while engine.has_pending_requests():
+        results = engine.step()
+        for idx, res in enumerate(results.outputs):
+            assert len(res.sequences) == num_sequences[idx]
+            for seq_id, seq in enumerate(res.sequences):
+                req_id = int(res.request_id)
+
+                if seq.delta:
+                    generated[int(req_id)][seq_id] += seq.delta
+
+                if seq.is_finished:
+                    unique_sequences[req_id].add(generated[req_id][seq_id])
+
+    for idx, response in enumerate(unique_sequences):
+        assert num_sequences[idx] == len(response)
+
+

 # These three models are used in _test_json_mode
 class France(BaseModel):
@@ -407,6 +491,7 @@ def _test_json_mode(
     # _test_stop(staging_engine)
     _test_logprobs(staging_engine)
     _test_logprobs_mixed_requests(staging_engine)
+    _test_num_sequences(staging_engine)
     _test_json_mode(staging_engine)
     # These tests are broken since we are now imposing no length limit
     # if max_tokens = None. The tests do not finish in a reasonable time.