diff --git a/serve/mlc_serve/engine/sampling_params.py b/serve/mlc_serve/engine/sampling_params.py
index f2638c7fa4..fba1e01028 100644
--- a/serve/mlc_serve/engine/sampling_params.py
+++ b/serve/mlc_serve/engine/sampling_params.py
@@ -88,6 +88,8 @@ def __post_init__(self):
             self._verify_greedy_sampling()
         if not self.logprobs:
             self.top_logprobs = 0
+        if self.top_k == -1:
+            self.top_k = self.vocab_size
 
     def verify(self) -> None:
         if not -2.0 <= self.presence_penalty <= 2.0:
@@ -99,15 +101,15 @@ def verify(self) -> None:
                 "frequency_penalty must be in [-2, 2], got "
                 f"{self.frequency_penalty}."
             )
-        if self.temperature < 0.0:
+        if not 0.0 <= self.temperature <= 2.0:
             raise ValueError(
-                f"temperature must be non-negative, got {self.temperature}."
+                f"temperature must be in [0, 2], got {self.temperature}."
             )
         if not 0.0 < self.top_p <= 1.0:
             raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.")
 
         if not isinstance(self.top_k, int):
-            raise ValueError(f"top_k must be integer.")
+            raise TypeError(f"top_k must be integer.")
 
         if self.top_k < -1 or self.top_k == 0:
             raise ValueError(
@@ -119,8 +121,10 @@ def verify(self) -> None:
                 raise ValueError(
                     f"logit bias must be in [-100, 100], got {bias} for token {token}."
                 )
-            if not 1 <= token <= self.vocab_size:
-                raise ValueError(f"token id must be in [1, vocab_size]")
+            if not isinstance(token, int):
+                raise ValueError(f"token id must be an integer")
+            if not 0 <= token < self.vocab_size:
+                raise ValueError(f"token id must be in [0, vocab_size)")
 
         if self.repetition_penalty <= 0:
             raise ValueError(
@@ -132,6 +136,10 @@ def verify(self) -> None:
             raise ValueError(
                 f"top_logprobs must be between 0 and {LOGPROB_TOP_K_MAX}, got {self.top_logprobs}."
             )
+        if not isinstance(self.top_logprobs, int):
+            raise TypeError(
+                "top_logprobs must be integer"
+            )
 
     def _verify_greedy_sampling(self) -> None:
         if self.top_p < 1.0 - _SAMPLING_EPS:
diff --git a/serve/mlc_serve/model/sampler.py b/serve/mlc_serve/model/sampler.py
index 10267238c4..31fc066513 100644
--- a/serve/mlc_serve/model/sampler.py
+++ b/serve/mlc_serve/model/sampler.py
@@ -172,8 +172,6 @@ def from_lists(
             device="cpu",
             pin_memory=True,
         )
-        # Convert 1-based index to 0-based
-        logit_bias_indices -= 1
         logit_bias_values = torch.tensor(
             list_logit_bias_values,
             dtype=dtype,
@@ -540,8 +538,8 @@ def _is_safe_to_sample(prob_like):
             assert sampling_state.sampling_params[batch_idx].logprobs
             top_k = sampling_state.sampling_params[batch_idx].top_logprobs
             logprob_infos[batch_idx] = RawLogprobsInfo(
-                current_token_id=next_token,
-                current_logprob=logprobs[batch_idx][next_token],
+                current_token_id=int(next_token),
+                current_logprob=float(logprobs[batch_idx][next_token]),
                 top_token_ids=top_tokens[idx][:top_k],
                 top_logprobs=top_logprobs[idx][:top_k],
             )
diff --git a/serve/tests/unittest/test_engine_with_samplers.py b/serve/tests/unittest/test_engine_with_samplers.py
index 18feed5a9d..543f876b67 100644
--- a/serve/tests/unittest/test_engine_with_samplers.py
+++ b/serve/tests/unittest/test_engine_with_samplers.py
@@ -11,7 +11,7 @@
 from mlc_serve.utils import get_default_mlc_serve_argparser, postproc_mlc_serve_args, create_mlc_engine
 import random
 from pydantic import BaseModel
-from typing import List
+from typing import List, Callable
 
 
 def create_request(
@@ -22,6 +22,7 @@ def create_request(
     pre_pen,
     max_tokens,
     stop,
+    num_sequences=1,
     ignore_eos=False,
     top_logprobs=0,
     logprobs=False,
@@ -41,6 +42,7 @@ def create_request(
             json_schema=json_schema,
         ),
         stopping_criteria=StoppingCriteria(max_tokens=max_tokens, stop_sequences=stop),
+        num_sequences=num_sequences,
         debug_options=DebugOptions(ignore_eos=ignore_eos),
     )
 
@@ -210,6 +212,45 @@ def _test_stop(
         )
         assert found == 1, f"{gen_txt!r}, matches: {found}"
+def _test_logit_bias(
+    engine,
+    num_requests=10
+):
+    prompt = "Repeat only one of the following words: hi, hello"
+    requests = []
+    for n in range(num_requests):
+        requests.append(
+            create_request(
+                idx=str(n),
+                prompt=prompt,
+                temp=0.8,
+                freq_pen=0,
+                pre_pen=0,
+                max_tokens=10,
+                stop="\n",
+                logit_bias={
+                    engine.tokenizer.encode("hi")[0]: -100.0,
+                    engine.tokenizer.encode("Hi")[0]: -100.0
+                }
+            )
+        )
+
+    engine.add(requests)
+    generated = ["" for _ in range(num_requests)]
+
+    while engine.has_pending_requests():
+        results = engine.step()
+        for res in results.outputs:
+            assert len(res.sequences) == 1
+            seq = res.sequences[0]
+            req_id = int(res.request_id)
+
+            if seq.delta:
+                generated[int(res.request_id)] += seq.delta
+
+            if seq.is_finished:
+                gen_txt = generated[req_id]
+                assert "hi" not in gen_txt and "Hi" not in gen_txt
 
 
 def _test_logprobs(
     engine,
@@ -257,6 +298,46 @@ def _test_logprobs(
             )
             generated[int(res.request_id)] += seq.delta
 
+    # As temperature increases, the spread between the highest and the
+    # lowest of the top logprobs in the response should decrease
+    temperatures = [0.2, 1.1, 2.0]
+    mean_bounds_diff = [0 for _ in range(num_requests * len(temperatures))]
+    for idx, temp in enumerate(temperatures):
+        requests = [
+            create_request(
+                idx=str(n),
+                prompt=random.choice(prompts),
+                temp=temp,
+                freq_pen=0,
+                pre_pen=0,
+                max_tokens=300,
+                stop=None,
+                ignore_eos=True,
+                logprobs=True,
+                top_logprobs=5
+            )
+            for n in range(num_requests)
+        ]
+        engine.add(requests)
+
+        while engine.has_pending_requests():
+            results = engine.step()
+            for res in results.outputs:
+                seq = res.sequences[0]
+                req = requests[int(res.request_id)]
+
+                if not seq.is_finished:
+                    mean_bounds_diff[idx * num_requests + int(res.request_id)] += \
+                        seq.logprob_info[0].top_logprobs[0].logprob \
+                        - seq.logprob_info[0].top_logprobs[4].logprob
+                else:
+                    mean_bounds_diff[idx * num_requests + int(res.request_id)] /= seq.num_generated_tokens
+
+    for num_req_batch in range(num_requests):
+        for idx in range(1, len(temperatures)):
+            assert mean_bounds_diff[idx * num_requests + num_req_batch] < \
+                mean_bounds_diff[(idx - 1) * num_requests + num_req_batch]
+
 
 def _test_logprobs_mixed_requests(
     engine,
@@ -301,6 +382,48 @@ def _test_logprobs_mixed_requests(
                 assert len(seq.logprob_info) == 0
             generated[int(res.request_id)] += seq.delta
+def _test_num_sequences(
+    engine,
+    num_requests=5,
+):
+    prompt = "Write a merge sort program in Python."
+    requests = []
+    num_sequences = [2 * i for i in range(1, num_requests + 1)]
+    for n, num_seq in enumerate(num_sequences):
+        requests.append(
+            create_request(
+                idx=str(n),
+                prompt=prompt,
+                temp=0.6,
+                freq_pen=0,
+                pre_pen=0,
+                stop=None,
+                max_tokens=300,
+                ignore_eos=False,
+                num_sequences=num_seq
+            )
+        )
+    engine.add(requests)
+
+    generated = [[""] * num_seq for _, num_seq in zip(range(num_requests), num_sequences)]
+    unique_sequences = [set() for _ in range(num_requests)]
+    while engine.has_pending_requests():
+        results = engine.step()
+        for idx, res in enumerate(results.outputs):
+            assert len(res.sequences) == num_sequences[idx]
+            for seq_id, seq in enumerate(res.sequences):
+                req_id = int(res.request_id)
+
+                if seq.delta:
+                    generated[int(req_id)][seq_id] += seq.delta
+
+                if seq.is_finished:
+                    unique_sequences[req_id].add(generated[req_id][seq_id])
+
+    for idx, response in enumerate(unique_sequences):
+        assert num_sequences[idx] == len(response)
+
+
 
 
 # These three models are used in _test_json_mode
 class France(BaseModel):
@@ -407,6 +530,8 @@ def _test_json_mode(
     # _test_stop(staging_engine)
     _test_logprobs(staging_engine)
     _test_logprobs_mixed_requests(staging_engine)
+    _test_num_sequences(staging_engine)
+    _test_logit_bias(staging_engine)
     _test_json_mode(staging_engine)
     # These tests are broken since we are now imposing no length limit
     # if max_tokens = None. The tests do not finish in a reasonable time.
@@ -422,6 +547,8 @@ def _test_json_mode(
     _test_stop(sync_engine)
     _test_logprobs(sync_engine)
     _test_logprobs_mixed_requests(sync_engine)
+    _test_num_sequences(sync_engine)
+    _test_logit_bias(sync_engine)
     _test_json_mode(sync_engine)
     # These tests are broken since we are now imposing no length limit
     # if max_tokens = None. The tests do not finish in a reasonable time.
diff --git a/serve/tests/unittest/test_sampler.py b/serve/tests/unittest/test_sampler.py
index eae3bf0ce3..d6fae49226 100644
--- a/serve/tests/unittest/test_sampler.py
+++ b/serve/tests/unittest/test_sampler.py
@@ -1,15 +1,17 @@
+import random
+from itertools import product, permutations
+from typing import List
+
 import torch
 import pytest
-from mlc_serve.model.sampler import SamplingState, adjust_logits
+from mlc_serve.model.sampler import SamplingState, adjust_logits, sample, SamplingOutput
 from mlc_serve.engine import SamplingParams, SAMPLING_EPS
-import random
 
 dtype = torch.float32
 dev = "cuda"
-vocab_size = 32000
 
 
-def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=None):
+def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=None, vocab_size=32000):
     batch_size = len(sampling_params)
     if past_output_tokens is None:
         past_output_tokens = [[] for _ in range(batch_size)]
@@ -31,136 +33,202 @@ def get_sampling_state(sampling_params, past_output_tokens=None, prompt_masks=No
     return sampling_state
 
 
-def test_temperature(temp=0, batch_size=1):
+def test_temperature_checker():
+    # temperature must be in [0, 2]
+    get_sampling_state([SamplingParams(temperature=0.0)])
+    get_sampling_state([SamplingParams(temperature=0.8)])
+    get_sampling_state([SamplingParams(temperature=1.3)])
+    get_sampling_state([SamplingParams(temperature=2.0)])
+
+    with pytest.raises(ValueError):
+        get_sampling_state([SamplingParams(temperature=-0.1)])
+
+    with pytest.raises(ValueError):
+        get_sampling_state([SamplingParams(temperature=2.1)])
+
+
+@pytest.mark.parametrize("batch_size", [1, 4, 8, 12])
+def test_temperature(batch_size: int):
+    vocab_size = 32000
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
-    sampling_param = SamplingParams(temperature=temp)
+    temperature = [0, 0.5, 1.0, 1.5, 2.0]
+    for batch_temp in permutations(temperature, batch_size):
+        sampling_params = [
+            SamplingParams(temperature=val)
+            for val in batch_temp
+        ]
+        expected = []
+        for idx, val in enumerate(batch_temp):
+            expected.append(logits[idx] / val if abs(val) > SAMPLING_EPS else logits[idx])
+        sampling_state = get_sampling_state(sampling_params)
+        new_logits = adjust_logits(logits, sampling_state, vocab_size)
+        for idx, response in enumerate(new_logits):
+            assert torch.allclose(expected[idx], response)
 
-    sampling_state = get_sampling_state([sampling_param])
-    expected = logits / temp if abs(temp) > SAMPLING_EPS else logits
-    new_logits = adjust_logits(logits, sampling_state, vocab_size)
-    assert torch.allclose(expected, new_logits)
 
+def test_logit_bias_checker():
+    # logit bias values must be in [-100, 100]
+    # and token indices must be in [0, vocab_size)
+    vocab_size = 32000
+    get_sampling_state([SamplingParams(logit_bias={1: 100, 3: -100, 2: 2})])
+    get_sampling_state([SamplingParams(logit_bias={34: 0, 23: -0.5})])
+    get_sampling_state([SamplingParams(logit_bias={1: 10, 3: -10, vocab_size - 1: 2})])
+    get_sampling_state([SamplingParams(logit_bias={})])
 
-def test_logit_bias_checker():
-    # logit bias must be [-100, 100]
     with pytest.raises(ValueError):
-        logit_bias = {1: 2, 3: 105, 2: 2}
-        sampling_param = SamplingParams(logit_bias=logit_bias)
-        get_sampling_state([sampling_param])
+        get_sampling_state([SamplingParams(logit_bias={1: 2, 3: 105, 2: 2})])
 
     with pytest.raises(ValueError):
-        logit_bias = {1: 99, 3: -101, 2: 2}
-        sampling_param = SamplingParams(logit_bias=logit_bias)
-        get_sampling_state([sampling_param])
-
-    logit_bias = {1: 100, 3: -100, 2: 2}
-    sampling_param = SamplingParams(logit_bias=logit_bias)
-    get_sampling_state([sampling_param])
-
-    # TODO(@team): it seems like the valid range is [1,vocab_size]. Double check.
-    logit_bias = {1: 10, 3: -10, vocab_size: 2}
-    sampling_param = SamplingParams(logit_bias=logit_bias)
-    get_sampling_state([sampling_param])
+        get_sampling_state([SamplingParams(logit_bias={1: 99, 3: -101, 2: 2})])
 
     with pytest.raises(ValueError):
-        logit_bias = {0: 10, 3: -10}
-        sampling_param = SamplingParams(logit_bias=logit_bias)
-        get_sampling_state([sampling_param])
+        get_sampling_state([SamplingParams(logit_bias={1: 10, 3: -10, vocab_size: 2})])
 
     with pytest.raises(ValueError):
-        logit_bias = {1: 10, 3: -10, vocab_size + 100: 2}
-        sampling_param = SamplingParams(logit_bias=logit_bias)
-        get_sampling_state([sampling_param])
+        get_sampling_state([SamplingParams(logit_bias={1: 10, 3: -10, vocab_size + 100: 2})])
 
     with pytest.raises(ValueError):
-        logit_bias = {1: 10, -1: -10}
-        sampling_param = SamplingParams(logit_bias=logit_bias)
-        get_sampling_state([sampling_param])
+        get_sampling_state([SamplingParams(logit_bias={1: 10, -1: -10})])
-
-def test_logit_bias():
-    # test single batch
-    batch_size = 1
+
+@pytest.mark.parametrize("batch_size", [1, 4])
+def test_logit_bias(batch_size: int):
+    vocab_size = 32000
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
-    logit_bias = {1: -1, 3: 1, 2: 2}
-    sampling_param = SamplingParams(logit_bias=logit_bias)
-    sampling_state = get_sampling_state([sampling_param])
-
+    sampling_param = [{} for _ in range(batch_size)]
+    for logit_bias_combination in permutations(
+        product(
+            [0, 31999, 724, 223],
+            [100, -100, -12.5, 0.05]
+        ),
+        batch_size
+    ):
+        for num_batch in range(len(logit_bias_combination)):
+            logit_index, logit_bias = logit_bias_combination[num_batch]
+            sampling_param[num_batch].update({logit_index: logit_bias})
     expected = torch.clone(logits)
-    for idx, val in logit_bias.items():
-        expected[0][idx - 1] += val
+    for num_batch in range(batch_size):
+        for idx, val in sampling_param[num_batch].items():
+            expected[num_batch][idx] += val
+    for idx, logit_bias in enumerate(sampling_param):
+        sampling_param[idx] = SamplingParams(logit_bias=logit_bias)
+    sampling_state = get_sampling_state(sampling_param)
     new_logits = adjust_logits(logits, sampling_state, vocab_size)
     assert torch.allclose(expected, new_logits)
 
-    # test multi-batch
-    batch_size = 3
-    shape = (batch_size, vocab_size)
-    logits = torch.rand(shape, dtype=dtype, device=dev)
-    list_logit_bias = [{1: -1, 3: 1, 2: 2}, {4: 2, 5: 1}, {1: -10}]
-    sampling_params = [
-        SamplingParams(logit_bias=logit_bias) for logit_bias in list_logit_bias
-    ]
-    sampling_state = get_sampling_state(sampling_params)
-    expected = torch.clone(logits)
-    for batch_size in range(batch_size):
-        logit_bias = list_logit_bias[batch_size]
-        for idx, val in logit_bias.items():
-            expected[batch_size][idx - 1] += val
-    new_logits = adjust_logits(logits, sampling_state, vocab_size)
-    assert torch.allclose(expected, new_logits)
+def test_penalties_checker():
+    # repetition_penalty must be > 0
+    # frequency_penalty must be in [-2, 2]
+    # presence_penalty must be in [-2, 2]
+    # repetition_penalty
+    get_sampling_state(
+        [SamplingParams(repetition_penalty=0.1)],
+    )
 
-def test_penalties_checker():
-    get_sampling_state([SamplingParams(presence_penalty=-1.0)])
-    get_sampling_state([SamplingParams(frequency_penalty=-1.0)])
-    get_sampling_state([SamplingParams(repetition_penalty=0.7)])
+    get_sampling_state(
+        [SamplingParams(repetition_penalty=2.0)],
+    )
 
     with pytest.raises(ValueError):
-        get_sampling_state([SamplingParams(presence_penalty=-2.1)])
+        get_sampling_state(
+            [SamplingParams(repetition_penalty=0.0)],
+        )
+
+    with pytest.raises(ValueError):
+        get_sampling_state(
+            [SamplingParams(repetition_penalty=-2.0)],
+        )
+
+    # frequency_penalty
+    get_sampling_state([SamplingParams(frequency_penalty=-2.0)])
+    get_sampling_state([SamplingParams(frequency_penalty=2.0)])
 
     with pytest.raises(ValueError):
         get_sampling_state([SamplingParams(frequency_penalty=-2.1)])
 
     with pytest.raises(ValueError):
-        get_sampling_state([SamplingParams(repetition_penalty=-2.1)])
+        get_sampling_state([SamplingParams(frequency_penalty=2.1)])
+
+    # presence_penalty
+    get_sampling_state([SamplingParams(presence_penalty=-2.0)])
+    get_sampling_state([SamplingParams(presence_penalty=2.0)])
+
+    with pytest.raises(ValueError):
+        get_sampling_state([SamplingParams(presence_penalty=-2.1)])
 
     with pytest.raises(ValueError):
         get_sampling_state([SamplingParams(presence_penalty=2.1)])
 
+    # combinations of penalties with valid values
+    get_sampling_state(
+        [SamplingParams(repetition_penalty=0.5, presence_penalty=0.5, frequency_penalty=0.0)],
+    )
+
+    # combinations of penalties with invalid values
     with pytest.raises(ValueError):
-        get_sampling_state([SamplingParams(frequency_penalty=2.1)])
+        get_sampling_state(
+            [SamplingParams(repetition_penalty=-0.5, presence_penalty=0.5, frequency_penalty=0.0)],
+        )
+
+    with pytest.raises(ValueError):
+        get_sampling_state(
+            [SamplingParams(repetition_penalty=0.5, presence_penalty=2.5, frequency_penalty=0.0)],
+        )
+
+    with pytest.raises(ValueError):
+        get_sampling_state(
+            [SamplingParams(repetition_penalty=0.5, presence_penalty=0.5, frequency_penalty=-3.0)],
+        )
+
+    # penalties with valid values in multi-batch
+    get_sampling_state(
+        [
+            SamplingParams(repetition_penalty=1.5),
+            SamplingParams(presence_penalty=0.5),
+            SamplingParams(frequency_penalty=0.0),
+        ],
+    )
+    # penalties with invalid values in multi-batch
     with pytest.raises(ValueError):
         get_sampling_state(
             [
-                SamplingParams(frequency_penalty=1.1),
-                SamplingParams(repetition_penalty=2.1),
+                SamplingParams(frequency_penalty=2.1),
+                SamplingParams(repetition_penalty=1.1),
                 SamplingParams(presence_penalty=1.1),
-                SamplingParams(presence_penalty=3.1),
-            ]
+                SamplingParams(frequency_penalty=1.1),
+            ],
         )
+    with pytest.raises(ValueError):
+        get_sampling_state(
+            [
+                SamplingParams(frequency_penalty=1.1),
+                SamplingParams(repetition_penalty=1.1),
+                SamplingParams(presence_penalty=1.1),
+                SamplingParams(repetition_penalty=0.0),
+            ],
+        )
 
-def test_penalties():
-    # TODO(vvchernov): Add test for repetition penalty
-    batch_size = 1
-    shape = (batch_size, vocab_size)
-    logits = torch.rand(shape, dtype=dtype, device=dev)
-    presence_penalties = [0.8]
-    frequency_penalties = [0.3]
-    past_output_tokens = [[2, 2, 2, 3]]
-    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
-    prompt_masks = [prompt_mask] * batch_size
+    with pytest.raises(ValueError):
+        get_sampling_state(
+            [
+                SamplingParams(frequency_penalty=1.1),
+                SamplingParams(repetition_penalty=1.1),
+                SamplingParams(presence_penalty=1.1),
+                SamplingParams(presence_penalty=2.1),
+            ],
+        )
 
-    def prepare_metadata(past_output_tokens):
+@pytest.mark.parametrize("batch_size", [1, 3])
+def test_penalties(batch_size: int):
+    def _prepare_metadata(past_output_tokens, vocab_size):
         count_map = []
         for past_output_tokens_per_req in past_output_tokens:
-            # TODO: Check if this is the right range
-            cnt = [0] * (vocab_size)
+            cnt = [0] * vocab_size
             for tok in past_output_tokens_per_req:
                 cnt[tok] += 1
             count_map.append(cnt)
@@ -169,187 +237,283 @@ def prepare_metadata(past_output_tokens):
         mask_tensor = count_tensor > 0
         return count_tensor, mask_tensor
 
-    count_map, mask = prepare_metadata(past_output_tokens)
-
-    def get_expected_result(
-        logits, count_map, mask, frequency_penalties, presence_penalties
+    def _get_expected_result(
+        logits,
+        count_map,
+        mask,
+        temperatures,
+        repetition_penalties,
+        presence_penalties,
+        frequency_penalties,
     ):
         expected = torch.clone(logits)
         for i in range(batch_size):
+            for j in range(len(expected[i])):
+                if mask[i][j]:
+                    expected[i][j] *= 1 / repetition_penalties[i] if expected[i][j] > 0 else repetition_penalties[i]
+            temperature = 1.0 if temperatures[i] < SAMPLING_EPS else temperatures[i]
             expected[i] = (
-                expected[i]
+                (expected[i]
                 - count_map[i] * frequency_penalties[i]
-                - mask[i] * presence_penalties[i]
+                - mask[i] * presence_penalties[i])
+                / temperature
             )
         return expected
 
-    expected = get_expected_result(
-        logits, count_map, mask, frequency_penalties, presence_penalties
-    )
-
-    sampling_param = [
-        SamplingParams(
-            presence_penalty=presence_penalties[0],
-            frequency_penalty=frequency_penalties[0],
-        )
-    ]
-    sampling_state = get_sampling_state(
-        sampling_param, past_output_tokens=past_output_tokens, prompt_masks=prompt_masks
-    )
-    new_logits = adjust_logits(logits, sampling_state, vocab_size)
-    assert torch.allclose(expected, new_logits)
-
-    batch_size = 3
+    vocab_size = 512
     shape = (batch_size, vocab_size)
     logits = torch.rand(shape, dtype=dtype, device=dev)
-    presence_penalties = [0.8, 0.7, -0.8]
-    frequency_penalties = [-0.3, 2.0, 1.2]
-    past_output_tokens = [[2, 2, 2, 3, 5], [3, 1, 2, 4], [3, 3, 1]]
-    prompt_mask = torch.zeros((vocab_size,), dtype=torch.bool)
-    prompt_masks = [prompt_mask] * batch_size
-
-    count_map, mask = prepare_metadata(past_output_tokens)
-    expected = get_expected_result(
-        logits, count_map, mask, frequency_penalties, presence_penalties
-    )
-
-    sampling_params = [
-        SamplingParams(
-            presence_penalty=presence_penalties[i],
-            frequency_penalty=frequency_penalties[i],
+    past_output_tokens = [[2, 2, 2, 3, 5]] * batch_size
+    count_map, mask = _prepare_metadata(past_output_tokens, vocab_size)
+
+    temperatures = [0.0, 0.6]
+    presence_penalties = [-2.0, 2.0]
+    frequency_penalties = [-2.0, 2.0]
+    repetition_penalties = [0.4, 1.0]
+    for batch_params in permutations(
+        product(
+            temperatures,
+            repetition_penalties,
+            presence_penalties,
+            frequency_penalties
+        ),
+        batch_size
+    ):
+        sampling_params = [
+            SamplingParams(
+                temperature=temp,
+                repetition_penalty=rep_pen,
+                presence_penalty=pr_pen,
+                frequency_penalty=fr_pen,
+                vocab_size=vocab_size
+            )
+            for temp, rep_pen, pr_pen, fr_pen in batch_params
+        ]
+        expected = _get_expected_result(
+            logits,
+            count_map,
+            mask,
+            [temp for temp, _, _, _ in batch_params],
+            [rep_pen for _, rep_pen, _, _ in batch_params],
+            [pr_pen for _, _, pr_pen, _ in batch_params],
+            [fr_pen for _, _, _, fr_pen in batch_params],
        )
-        for i in range(batch_size)
-    ]
-    sampling_state = get_sampling_state(
-        sampling_params,
-        past_output_tokens=past_output_tokens,
-        prompt_masks=prompt_masks,
-    )
-    new_logits = adjust_logits(logits, sampling_state, vocab_size)
-    assert torch.allclose(expected, new_logits)
+        sampling_state = get_sampling_state(
+            sampling_params, past_output_tokens=past_output_tokens, vocab_size=vocab_size
+        )
+        new_logits = adjust_logits(logits, sampling_state, vocab_size)
+        assert torch.allclose(expected, new_logits)
 
 
 def test_top_p_top_k_checker():
-    get_sampling_state([SamplingParams(top_p=0.8)])
-    get_sampling_state([SamplingParams(top_k=3)])
+    # top_p must be in (0, 1]
+    # top_k must be in (0, vocab_size] (use -1 to consider all tokens)
+    # top_p
+    get_sampling_state([SamplingParams(top_p=0.6)])
+    get_sampling_state([SamplingParams(top_p=0.1)])
+    get_sampling_state([SamplingParams(top_p=1.0)])
+
+    # top_k
+    get_sampling_state([SamplingParams(top_k=3)])
     get_sampling_state([SamplingParams(top_k=-1)])
     get_sampling_state([SamplingParams(top_k=1)])
 
-    with pytest.raises(ValueError):
-        get_sampling_state([SamplingParams(top_p=0.0)])
+    # combinations of top_p, top_k with valid values
+    get_sampling_state([SamplingParams(top_p=0.1, top_k=128)])
+    get_sampling_state([SamplingParams(top_p=0.6, top_k=1)])
+    get_sampling_state([SamplingParams(top_p=1.0, top_k=-1)])
 
+    # combinations of top_p, top_k with invalid values
     with pytest.raises(ValueError):
-        get_sampling_state([SamplingParams(top_p=-0.8)])
-
+        get_sampling_state([SamplingParams(top_p=0.0, top_k=128)])
     with pytest.raises(ValueError):
-        get_sampling_state([SamplingParams(top_k=0)])
-
+        get_sampling_state([SamplingParams(top_p=-1, top_k=-5)])
     with pytest.raises(ValueError):
-        get_sampling_state([SamplingParams(top_k=0.8)])
+        get_sampling_state([SamplingParams(top_p=5, top_k=0)])
+
+    # top_p, top_k with valid values in multi-batch
+    get_sampling_state(
+        [
+            SamplingParams(top_p=0.1, top_k=128),
+            SamplingParams(top_p=0.5, top_k=1024),
+            SamplingParams(top_p=1.0, top_k=8),
+        ]
+    )
+    get_sampling_state(
+        [SamplingParams(top_p=0.1), SamplingParams(top_p=0.5, top_k=1024), SamplingParams(top_k=8)]
+    )
+    get_sampling_state(
+        [
+            SamplingParams(top_p=1.0, top_k=-1),
+            SamplingParams(top_p=0.5, top_k=32000),
+        ]
+    )
+    # top_p, top_k with invalid values in multi-batch
     with pytest.raises(ValueError):
-        get_sampling_state([SamplingParams(top_k=-2)])
-
-
-def test_top_p_top_k():
-    def get_expected_result(logits, top_pks, filter_value=-float("Inf")):
-        """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
-        Args:
-            logits: logits distribution shape (vocabulary size)
-            top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
-                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
-            top_k >0: keep only top k tokens with highest probability (top-k filtering).
- """ - batch_size = len(top_pks) - lst_logits = [] - for ii in range(batch_size): - _logits = logits[ii] - top_p, top_k = top_pks[ii] - if top_p > 0.0: - sorted_logits, sorted_indices = torch.sort(_logits, descending=True) - cumulative_probs = torch.cumsum( - torch.softmax(sorted_logits, dim=-1), dim=-1 - ) - - # Remove tokens with cumulative probability above the threshold - sorted_indices_to_remove = cumulative_probs > top_p - # Shift the indices to the right to keep also the first token above the threshold - sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[ - ..., :-1 - ].clone() - sorted_indices_to_remove[..., 0] = 0 - - indices_to_remove = sorted_indices[sorted_indices_to_remove] - _logits[indices_to_remove] = filter_value - - if top_k > 0: - # Remove all tokens with a probability less than the last token of the top-k - top_k_values = torch.topk(_logits, top_k)[0] - # Use `None` to insert a singleton dimension - # Equivalent to apply `squeeze` to the given dimension - # e.g., arr.shape = [3,3] - # arr[:,:,None].shape = [3,3,1] - indices_to_remove = _logits < top_k_values[..., -1, None] - _logits[indices_to_remove] = filter_value - - lst_logits.append(_logits) - return torch.stack(lst_logits) - - batch_size = 1 - top_p, top_k = 0.7, 5 - shape = (batch_size, vocab_size) - logits = torch.rand(shape, dtype=dtype, device=dev) - sampling_params = [ - SamplingParams(top_p=top_p, top_k=top_k) for _ in range(batch_size) - ] - sampling_state = get_sampling_state(sampling_params) - new_logits = adjust_logits(logits, sampling_state, vocab_size) - expected = logits.clone() - expected = get_expected_result(expected, top_pks=[(top_p, top_k)]) - assert torch.allclose(expected, new_logits) + get_sampling_state( + [ + SamplingParams(top_p=-1, top_k=128), + SamplingParams(top_p=0.5, top_k=12), + ] + ) + with pytest.raises(ValueError): + get_sampling_state( + [ + SamplingParams(top_p=0.1), + SamplingParams(top_k=-2), + ] + ) + with pytest.raises(ValueError): + get_sampling_state( + [ + SamplingParams(top_p=1.1, top_k=-1), + SamplingParams(top_p=0.5, top_k=64), + ] + ) - batch_size = 3 +def get_expected_result_by_top_pks(logits, top_pks, temps=None, filter_value=-float("Inf")): + """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering + Args: + logits: logits distribution shape (vocabulary size) + top_p >0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering). + Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751) + top_k >0: keep only top k tokens with highest probability (top-k filtering). 
+ """ + batch_size = len(top_pks) + lst_logits = [] + if temps is None: + temps = [1.0] * batch_size + for ii in range(batch_size): + if temps[ii] < SAMPLING_EPS: + temps[ii] = 1.0 + _logits = logits[ii] / temps[ii] + top_p, top_k = top_pks[ii] + if top_p > 0.0: + sorted_logits, sorted_indices = torch.sort(_logits, descending=True) + cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) + + # Remove tokens with cumulative probability above the threshold + sorted_indices_to_remove = cumulative_probs > top_p + # Shift the indices to the right to keep also the first token above the threshold + sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() + sorted_indices_to_remove[..., 0] = 0 + + indices_to_remove = sorted_indices[sorted_indices_to_remove] + _logits[indices_to_remove] = filter_value + + if top_k > 0: + # Remove all tokens with a probability less than the last token of the top-k + top_k_values = torch.topk(_logits, top_k)[0] + # Use `None` to insert a singleton dimension + # Equivalent to apply `squeeze` to the given dimension + # e.g., arr.shape = [3,3] + # arr[:,:,None].shape = [3,3,1] + indices_to_remove = _logits < top_k_values[..., -1, None] + _logits[indices_to_remove] = filter_value + + lst_logits.append(_logits) + return torch.stack(lst_logits) + +@pytest.mark.parametrize("batch_size", [1, 4]) +def test_top_p_top_k(batch_size: int): + vocab_size = 32000 shape = (batch_size, vocab_size) logits = torch.rand(shape, dtype=dtype, device=dev) - top_pks = [(0.7, 3), (0.5, 2), (0.8, 5)] - sampling_params = [ - SamplingParams(top_p=top_p, top_k=top_k) for top_p, top_k in top_pks - ] - sampling_state = get_sampling_state(sampling_params) + for top_pks in permutations( + product( + [0.3, 0.7], # top_p + [128, 2048, 32000] # top_k + ), + batch_size + ): + sampling_params = [SamplingParams(top_p=top_p, top_k=top_k) for top_p, top_k in top_pks] + sampling_state = get_sampling_state(sampling_params) + new_logits = adjust_logits(logits, sampling_state, vocab_size) + expected = get_expected_result_by_top_pks(logits.clone(), top_pks) + assert torch.allclose(expected, new_logits) - new_logits = adjust_logits(logits, sampling_state, vocab_size) - expected = logits.clone() - expected = get_expected_result(expected, top_pks) - assert torch.allclose(expected, new_logits) +def test_logprobs_checker(): + get_sampling_state([SamplingParams(logprobs=False)]) + get_sampling_state([SamplingParams(logprobs=True)]) + get_sampling_state([SamplingParams(logprobs=True, top_logprobs=0)]) + get_sampling_state([SamplingParams(logprobs=True, top_logprobs=5)]) -def test_mixture_of_requests(): - # Mixed greedy & top_p/top_ks - batch_size = 6 + with pytest.raises(ValueError): + get_sampling_state([SamplingParams(logprobs=True, top_logprobs=-1)]) + + with pytest.raises(ValueError): + get_sampling_state([SamplingParams(logprobs=True, top_logprobs=6)]) + + with pytest.raises(TypeError): + get_sampling_state([SamplingParams(logprobs=True, top_logprobs=2.5)]) + +@pytest.mark.parametrize("batch_size", [1, 4, 8]) +def test_logprobs(batch_size: int): + vocab_size = 32000 shape = (batch_size, vocab_size) logits = torch.rand(shape, dtype=dtype, device=dev) - top_pks = [(0.7, 3), (1.0, -1), (1.0, -1), (0.5, 2), (1.0, -1), (0.8, 5)] - temps = [0.8, 0.8, 0.0, 0.0, 0.0, 0.7] - sampling_params = [ - SamplingParams(temperature=temps[i], top_p=top_p, top_k=top_k) - for i, (top_p, top_k) in enumerate(top_pks) - ] + + # No logprobs + sampling_params = [SamplingParams(logprobs=False) for 
_ in range(batch_size)] sampling_state = get_sampling_state(sampling_params) - new_logits = adjust_logits(logits, sampling_state, vocab_size) + output: SamplingOutput = sample(logits, sampling_state) + assert all([logprob_response is None for logprob_response in output.logprob_infos]) - # TODO(team): please follow-up. correctness check - # expected = logits.clone() - # expected = get_expected_result(expected, top_pks) - # assert torch.allclose(expected, new_logits) - - -if __name__ == "__main__": - test_temperature() - test_logit_bias_checker() - test_logit_bias() - test_penalties_checker() - test_penalties() - test_top_p_top_k_checker() - test_top_p_top_k() - test_mixture_of_requests() + # Logprob only of a current token + sampling_params = [SamplingParams(logprobs=True) for _ in range(batch_size)] + sampling_state = get_sampling_state(sampling_params) + output: SamplingOutput = sample(logits, sampling_state) + assert len(output.logprob_infos) == batch_size + for idx in range(batch_size): + assert isinstance(output.logprob_infos[idx].current_token_id, int) + assert isinstance(output.logprob_infos[idx].current_logprob, float) + assert output.logprob_infos[idx].top_token_ids.nelement() == 0 + assert output.logprob_infos[idx].top_logprobs.nelement() == 0 + + # Top-k logprobs + for top_logprobs in [1, 3, 5]: + sampling_params = [ + SamplingParams(logprobs=True, top_logprobs=top_logprobs) for _ in range(batch_size) + ] + sampling_state = get_sampling_state(sampling_params) + output: SamplingOutput = sample(logits, sampling_state) + assert len(output.logprob_infos) == batch_size + for idx in range(batch_size): + assert isinstance(output.logprob_infos[idx].current_token_id, int) + assert isinstance(output.logprob_infos[idx].current_logprob, float) + assert output.logprob_infos[idx].top_token_ids.nelement() != 0 + assert len(output.logprob_infos[idx].top_token_ids) == top_logprobs + assert output.logprob_infos[idx].top_logprobs.nelement() != 0 + assert len(output.logprob_infos[idx].top_logprobs) == top_logprobs + +@pytest.mark.skip(reason=""" + This test is currently broken. Need to validate correctness of this check + and make sure that _apply_top_p_top_k from sampler.py does not produce too many -inf values + """) +@pytest.mark.parametrize("batch_size", [1, 4, 8, 12]) +def test_mixture_of_requests(batch_size: int): + # Mixed temperature & top_p/top_ks + vocab_size = 32000 + top_ps = list(torch.arange(1, 0, -0.01)) + top_ks = list(range(1, vocab_size + 1)) + temperatures = list(torch.arange(0, 2.1, 0.1)) + temp_weights = [0.5] + temp_weights.extend([1 / (len(temperatures) - 1)] * (len(temperatures) - 1)) + top_ks.append(-1) + for _ in range(10): + shape = (batch_size, vocab_size) + logits = torch.rand(shape, dtype=dtype, device=dev) + top_pks = [(random.choice(top_ps), random.choice(top_ks)) for _ in range(batch_size)] + temps = random.choices(temperatures, weights=temp_weights, k=batch_size) + sampling_params = [ + SamplingParams(temperature=temps[i], top_p=top_p, top_k=top_k) + for i, (top_p, top_k) in enumerate(top_pks) + ] + sampling_state = get_sampling_state(sampling_params) + new_logits = adjust_logits(logits, sampling_state, vocab_size) + expected = get_expected_result_by_top_pks(logits.clone(), top_pks, temps) + assert torch.allclose(expected, new_logits) \ No newline at end of file