From 381dd57bd69f027a3298d107d8eb851c3c29d8e4 Mon Sep 17 00:00:00 2001 From: Liangsheng Yin Date: Wed, 28 Aug 2024 18:58:52 -0700 Subject: [PATCH] Sampler cudagraph (#1253) --- python/sglang/bench_latency.py | 14 ++-- python/sglang/srt/layers/logits_processor.py | 8 +- python/sglang/srt/layers/sampler.py | 83 +++++++++++++++---- python/sglang/srt/managers/schedule_batch.py | 28 +++++-- python/sglang/srt/managers/tp_worker.py | 52 +++++++----- .../srt/model_executor/cuda_graph_runner.py | 33 ++++++-- .../srt/model_executor/forward_batch_info.py | 7 ++ .../sglang/srt/model_executor/model_runner.py | 14 +++- python/sglang/srt/models/chatglm.py | 16 +--- python/sglang/srt/models/commandr.py | 6 +- python/sglang/srt/models/dbrx.py | 6 +- python/sglang/srt/models/deepseek.py | 6 +- python/sglang/srt/models/deepseek_v2.py | 6 +- python/sglang/srt/models/gemma.py | 6 +- python/sglang/srt/models/gemma2.py | 6 +- python/sglang/srt/models/gpt_bigcode.py | 6 +- python/sglang/srt/models/grok.py | 6 +- python/sglang/srt/models/internlm2.py | 6 +- python/sglang/srt/models/llama2.py | 10 ++- .../sglang/srt/models/llama_classification.py | 4 +- python/sglang/srt/models/minicpm.py | 6 +- python/sglang/srt/models/mixtral.py | 6 +- python/sglang/srt/models/mixtral_quant.py | 6 +- python/sglang/srt/models/qwen.py | 7 +- python/sglang/srt/models/qwen2.py | 8 +- python/sglang/srt/models/qwen2_moe.py | 19 ++--- python/sglang/srt/models/stablelm.py | 6 +- .../srt/sampling/sampling_batch_info.py | 75 ++++++++++++++++- python/sglang/test/runners.py | 2 +- 29 files changed, 342 insertions(+), 116 deletions(-) diff --git a/python/sglang/bench_latency.py b/python/sglang/bench_latency.py index dea910f5772..3a487408573 100644 --- a/python/sglang/bench_latency.py +++ b/python/sglang/bench_latency.py @@ -200,16 +200,16 @@ def extend(reqs, model_runner): tree_cache=None, ) batch.prepare_for_extend(model_runner.model_config.vocab_size) - output = model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids = batch.sample(output.next_token_logits) - return next_token_ids, output.next_token_logits, batch + sample_output, logits_output = model_runner.forward(batch, ForwardMode.EXTEND) + next_token_ids = sample_output.batch_next_token_ids.tolist() + return next_token_ids, logits_output.next_token_logits, batch def decode(input_token_ids, batch, model_runner): - batch.prepare_for_decode(input_token_ids.cpu().numpy()) - output = model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids = batch.sample(output.next_token_logits) - return next_token_ids, output.next_token_logits + batch.prepare_for_decode(input_token_ids) + sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE) + next_token_ids = sample_output.batch_next_token_ids.tolist() + return next_token_ids, logits_output.next_token_logits @torch.inference_mode() diff --git a/python/sglang/srt/layers/logits_processor.py b/python/sglang/srt/layers/logits_processor.py index 63f74d8b026..b81f3d2a040 100644 --- a/python/sglang/srt/layers/logits_processor.py +++ b/python/sglang/srt/layers/logits_processor.py @@ -29,7 +29,7 @@ @dataclasses.dataclass -class LogitProcessorOutput: +class LogitsProcessorOutput: # The logits of the next tokens. shape: [#seq, vocab_size] next_token_logits: torch.Tensor # The logprobs of the next tokens. 
shape: [#seq, vocab_size] @@ -185,7 +185,7 @@ def forward( # Return only last_logits if logprob is not requested if not logits_metadata.return_logprob: - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=None, normalized_prompt_logprobs=None, @@ -209,7 +209,7 @@ def forward( else: output_top_logprobs = None - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=None, @@ -278,7 +278,7 @@ def forward( # Remove the last token logprob for the prefill tokens. input_token_logprobs = input_token_logprobs[:-1] - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=last_logits, next_token_logprobs=last_logprobs, normalized_prompt_logprobs=normalized_prompt_logprobs, diff --git a/python/sglang/srt/layers/sampler.py b/python/sglang/srt/layers/sampler.py index 3006e765c88..6cb7d0a7c11 100644 --- a/python/sglang/srt/layers/sampler.py +++ b/python/sglang/srt/layers/sampler.py @@ -1,4 +1,6 @@ +import dataclasses import logging +from typing import Union import torch from flashinfer.sampling import ( @@ -9,6 +11,8 @@ ) from vllm.model_executor.custom_op import CustomOp +from sglang.srt.layers.logits_processor import LogitsProcessorOutput + # TODO: move this dict to another place from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo @@ -16,30 +20,71 @@ logger = logging.getLogger(__name__) +@dataclasses.dataclass +class SampleOutput: + success: torch.Tensor + probs: torch.Tensor + batch_next_token_ids: torch.Tensor + + class Sampler(CustomOp): def __init__(self): super().__init__() - def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): + def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo): + # min-token, presence, frequency + if sampling_info.linear_penalties is not None: + logits += sampling_info.linear_penalties + + # repetition + if sampling_info.scaling_penalties is not None: + logits = torch.where( + logits > 0, + logits / sampling_info.scaling_penalties, + logits * sampling_info.scaling_penalties, + ) + + return logits + + def _get_probs( + self, + logits: torch.Tensor, + sampling_info: SamplingBatchInfo, + is_torch_compile: bool = False, + ): # Post process logits logits = logits.contiguous() logits.div_(sampling_info.temperatures) + if is_torch_compile: + # FIXME: Temporary workaround for unknown bugs in torch.compile + logits.add_(0) + if sampling_info.logit_bias is not None: logits.add_(sampling_info.logit_bias) if sampling_info.vocab_mask is not None: logits = logits.masked_fill(~sampling_info.vocab_mask, float("-inf")) - logits = sampling_info.penalizer_orchestrator.apply(logits) + logits = self._apply_penalties(logits, sampling_info) - probs = torch.softmax(logits, dim=-1) + return torch.softmax(logits, dim=-1) + + def forward_cuda( + self, + logits: Union[torch.Tensor, LogitsProcessorOutput], + sampling_info: SamplingBatchInfo, + ): + if isinstance(logits, LogitsProcessorOutput): + logits = logits.next_token_logits + + probs = self._get_probs(logits, sampling_info) if not global_server_args_dict["disable_flashinfer_sampling"]: max_top_k_round, batch_size = 32, probs.shape[0] uniform_samples = torch.rand( (max_top_k_round, batch_size), device=probs.device ) - if sampling_info.min_ps.any(): + if sampling_info.need_min_p_sampling: probs = top_k_renorm_prob(probs, 
sampling_info.top_ks)
                 probs = top_p_renorm_prob(probs, sampling_info.top_ps)
                 batch_next_token_ids, success = min_p_sampling_from_probs(
@@ -55,18 +100,23 @@ def forward_cuda(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
                 probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
             )
 
-        if not torch.all(success):
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                success, batch_next_token_ids, argmax_ids
-            )
+        return SampleOutput(success, probs, batch_next_token_ids)
 
-        return batch_next_token_ids
+    def forward_native(
+        self,
+        logits: Union[torch.Tensor, LogitsProcessorOutput],
+        sampling_info: SamplingBatchInfo,
+    ):
+        if isinstance(logits, LogitsProcessorOutput):
+            logits = logits.next_token_logits
+
+        probs = self._get_probs(logits, sampling_info, is_torch_compile=True)
+
+        batch_next_token_ids, success = top_k_top_p_min_p_sampling_from_probs_torch(
+            probs, sampling_info.top_ks, sampling_info.top_ps, sampling_info.min_ps
+        )
 
-    def forward_native():
-        raise NotImplementedError("Native forward is not implemented yet.")
+        return SampleOutput(success, probs, batch_next_token_ids)
 
 
 def top_k_top_p_min_p_sampling_from_probs_torch(
@@ -87,7 +137,10 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
     try:
-        sampled_index = torch.multinomial(probs_sort, num_samples=1)
+        # FIXME: torch.multinomial does not support num_samples = 1
+        sampled_index = torch.multinomial(probs_sort, num_samples=2, replacement=True)[
+            :, :1
+        ]
     except RuntimeError as e:
         logger.warning(f"Sampling error: {e}")
         batch_next_token_ids = torch.zeros(
diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py
index 5554170a350..f5b9c9eb27d 100644
--- a/python/sglang/srt/managers/schedule_batch.py
+++ b/python/sglang/srt/managers/schedule_batch.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,7 +19,7 @@
 import logging
 from dataclasses import dataclass
-from typing import List, Optional, Union
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 
@@ -29,6 +31,10 @@
 from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.sampler import SampleOutput
+
+
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
 # Put some global args for easy access
@@ -678,11 +684,17 @@ def merge(self, other: "ScheduleBatch"):
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
 
-    def sample(self, logits: torch.Tensor):
-        from sglang.srt.layers.sampler import Sampler
-
-        sampler = Sampler()
-
-        batch_next_token_ids = sampler(logits, self.sampling_info)
+    def check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+ ) + sample_output.probs = probs + sample_output.batch_next_token_ids = batch_next_token_ids - return batch_next_token_ids + return sample_output.batch_next_token_ids diff --git a/python/sglang/srt/managers/tp_worker.py b/python/sglang/srt/managers/tp_worker.py index cd1b580643c..123b1f5d5dc 100644 --- a/python/sglang/srt/managers/tp_worker.py +++ b/python/sglang/srt/managers/tp_worker.py @@ -31,7 +31,7 @@ from sglang.srt.constrained.fsm_cache import FSMCache from sglang.srt.constrained.jump_forward import JumpForwardCache from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer -from sglang.srt.layers.logits_processor import LogitProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.managers.io_struct import ( AbortReq, BatchEmbeddingOut, @@ -504,21 +504,29 @@ def forward_prefill_batch(self, batch: ScheduleBatch): if self.model_runner.is_generation: # Forward and sample the next tokens if batch.extend_num_tokens != 0: - output = self.model_runner.forward(batch, ForwardMode.EXTEND) - next_token_ids = batch.sample(output.next_token_logits) + sample_output, logits_output = self.model_runner.forward( + batch, ForwardMode.EXTEND + ) + next_token_ids = batch.check_sample_results(sample_output) batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) # Move logprobs to cpu - if output.next_token_logprobs is not None: - output.next_token_logprobs = output.next_token_logprobs[ - torch.arange(len(next_token_ids), device=next_token_ids.device), - next_token_ids, - ].tolist() - output.input_token_logprobs = output.input_token_logprobs.tolist() - output.normalized_prompt_logprobs = ( - output.normalized_prompt_logprobs.tolist() + if logits_output.next_token_logprobs is not None: + logits_output.next_token_logprobs = ( + logits_output.next_token_logprobs[ + torch.arange( + len(next_token_ids), device=next_token_ids.device + ), + next_token_ids, + ].tolist() + ) + logits_output.input_token_logprobs = ( + logits_output.input_token_logprobs.tolist() + ) + logits_output.normalized_prompt_logprobs = ( + logits_output.normalized_prompt_logprobs.tolist() ) next_token_ids = next_token_ids.tolist() @@ -557,12 +565,14 @@ def forward_prefill_batch(self, batch: ScheduleBatch): self.req_to_token_pool.free(req.req_pool_idx) if req.return_logprob: - self.add_logprob_return_values(i, req, pt, next_token_ids, output) + self.add_logprob_return_values( + i, req, pt, next_token_ids, logits_output + ) pt += req.extend_input_len else: assert batch.extend_num_tokens != 0 - output = self.model_runner.forward(batch, ForwardMode.EXTEND) - embeddings = output.embeddings.tolist() + logits_output = self.model_runner.forward(batch, ForwardMode.EXTEND) + embeddings = logits_output.embeddings.tolist() # Check finish conditions for i, req in enumerate(batch.reqs): @@ -590,7 +600,7 @@ def add_logprob_return_values( req: Req, pt: int, next_token_ids: List[int], - output: LogitProcessorOutput, + output: LogitsProcessorOutput, ): if req.normalized_prompt_logprob is None: req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i] @@ -672,15 +682,17 @@ def forward_decode_batch(self, batch: ScheduleBatch): batch.prepare_for_decode() # Forward and sample the next tokens - output = self.model_runner.forward(batch, ForwardMode.DECODE) - next_token_ids = batch.sample(output.next_token_logits) + sample_output, logits_output = self.model_runner.forward( + batch, ForwardMode.DECODE + ) + next_token_ids = batch.check_sample_results(sample_output) 
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens( next_token_ids ) # Move logprobs to cpu - if output.next_token_logprobs is not None: - next_token_logprobs = output.next_token_logprobs[ + if logits_output.next_token_logprobs is not None: + next_token_logprobs = logits_output.next_token_logprobs[ torch.arange(len(next_token_ids), device=next_token_ids.device), next_token_ids, ].tolist() @@ -706,7 +718,7 @@ def forward_decode_batch(self, batch: ScheduleBatch): (next_token_logprobs[i], next_token_id) ) if req.top_logprobs_num > 0: - req.output_top_logprobs.append(output.output_top_logprobs[i]) + req.output_top_logprobs.append(logits_output.output_top_logprobs[i]) self.handle_finished_requests(batch) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 796db26623f..40c87af88cf 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -26,16 +26,18 @@ from vllm.model_executor.custom_op import CustomOp from sglang.srt.layers.logits_processor import ( - LogitProcessorOutput, LogitsMetadata, LogitsProcessor, + LogitsProcessorOutput, ) +from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch from sglang.srt.model_executor.forward_batch_info import ( ForwardMode, InputMetadata, update_flashinfer_indices, ) +from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo from sglang.srt.utils import monkey_patch_vllm_all_gather @@ -144,6 +146,10 @@ def __init__( self.flashinfer_kv_indices.clone(), ] + # Sampling inputs + vocab_size = model_runner.model_config.vocab_size + self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size) + self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else [] if use_torch_compile: @@ -235,6 +241,7 @@ def capture_one_batch_size(self, bs: int, forward: Callable): def run_once(): input_metadata = InputMetadata( forward_mode=ForwardMode.DECODE, + sampling_info=self.sampling_info[:bs], batch_size=bs, req_pool_indices=req_pool_indices, seq_lens=seq_lens, @@ -299,27 +306,35 @@ def replay(self, batch: ScheduleBatch): self.flashinfer_handlers[bs], ) + # Sampling inputs + self.sampling_info.inplace_assign(raw_bs, batch.sampling_info) + # Replay torch.cuda.synchronize() self.graphs[bs].replay() torch.cuda.synchronize() - output = self.output_buffers[bs] + sample_output, logits_output = self.output_buffers[bs] # Unpad if bs != raw_bs: - output = LogitProcessorOutput( - next_token_logits=output.next_token_logits[:raw_bs], + logits_output = LogitsProcessorOutput( + next_token_logits=logits_output.next_token_logits[:raw_bs], next_token_logprobs=None, normalized_prompt_logprobs=None, input_token_logprobs=None, input_top_logprobs=None, output_top_logprobs=None, ) + sample_output = SampleOutput( + sample_output.success[:raw_bs], + sample_output.probs[:raw_bs], + sample_output.batch_next_token_ids[:raw_bs], + ) # Extract logprobs if batch.return_logprob: - output.next_token_logprobs = torch.nn.functional.log_softmax( - output.next_token_logits, dim=-1 + logits_output.next_token_logprobs = torch.nn.functional.log_softmax( + logits_output.next_token_logits, dim=-1 ) return_top_logprob = any(x > 0 for x in batch.top_logprobs_nums) if return_top_logprob: @@ -327,8 +342,8 @@ def replay(self, batch: ScheduleBatch): forward_mode=ForwardMode.DECODE, top_logprobs_nums=batch.top_logprobs_nums, ) - output.output_top_logprobs = 
LogitsProcessor.get_top_logprobs( - output.next_token_logprobs, logits_metadata + logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs( + logits_output.next_token_logprobs, logits_metadata )[1] - return output + return sample_output, logits_output diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index f24cdf6b723..3d40c9d7558 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -1,3 +1,5 @@ +from __future__ import annotations + """ Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); @@ -26,6 +28,7 @@ if TYPE_CHECKING: from sglang.srt.model_executor.model_runner import ModelRunner + from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo class ForwardMode(IntEnum): @@ -42,6 +45,7 @@ class InputMetadata: """Store all inforamtion of a forward pass.""" forward_mode: ForwardMode + sampling_info: SamplingBatchInfo batch_size: int req_pool_indices: torch.Tensor seq_lens: torch.Tensor @@ -169,6 +173,7 @@ def from_schedule_batch( ): ret = cls( forward_mode=forward_mode, + sampling_info=batch.sampling_info, batch_size=batch.batch_size(), req_pool_indices=batch.req_pool_indices, seq_lens=batch.seq_lens, @@ -179,6 +184,8 @@ def from_schedule_batch( top_logprobs_nums=batch.top_logprobs_nums, ) + ret.sampling_info.prepare_penalties() + ret.compute_positions(batch) ret.compute_extend_infos(batch) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 8ef47a530f5..e6f5e743110 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -21,7 +21,7 @@ import logging import pkgutil from functools import lru_cache -from typing import Optional, Type +from typing import Optional, Tuple, Type import torch import torch.nn as nn @@ -44,6 +44,8 @@ from vllm.model_executor.models import ModelRegistry from sglang.global_config import global_config +from sglang.srt.layers.logits_processor import LogitsProcessorOutput +from sglang.srt.layers.sampler import SampleOutput from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict from sglang.srt.mem_cache.memory_pool import ( MHATokenToKVPool, @@ -524,7 +526,11 @@ def init_cuda_graphs(self): @torch.inference_mode() def forward_decode(self, batch: ScheduleBatch): - if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)): + if ( + self.cuda_graph_runner + and self.cuda_graph_runner.can_run(len(batch.reqs)) + and not batch.sampling_info.has_bias() + ): return self.cuda_graph_runner.replay(batch) input_metadata = InputMetadata.from_schedule_batch( @@ -573,7 +579,9 @@ def forward_extend_multi_modal(self, batch: ScheduleBatch): input_metadata.image_offsets, ) - def forward(self, batch: ScheduleBatch, forward_mode: ForwardMode): + def forward( + self, batch: ScheduleBatch, forward_mode: ForwardMode + ) -> Tuple[SampleOutput, LogitsProcessorOutput]: if self.is_multimodal_model and forward_mode == ForwardMode.EXTEND: return self.forward_extend_multi_modal(batch) elif forward_mode == ForwardMode.DECODE: diff --git a/python/sglang/srt/models/chatglm.py b/python/sglang/srt/models/chatglm.py index b38b62fafd3..9eb04dc263d 100644 --- a/python/sglang/srt/models/chatglm.py +++ b/python/sglang/srt/models/chatglm.py @@ -31,20 +31,18 @@ ) from vllm.model_executor.layers.quantization.base_config 
import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.sequence import SamplerOutput from vllm.transformers_utils.configs import ChatGLMConfig from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata LoraConfig = None @@ -383,17 +381,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/commandr.py b/python/sglang/srt/models/commandr.py index f6d6f6e1f94..c360106f97c 100644 --- a/python/sglang/srt/models/commandr.py +++ b/python/sglang/srt/models/commandr.py @@ -64,6 +64,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -326,6 +327,7 @@ def __init__( self.config = config self.quant_config = quant_config self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() self.model = CohereModel(config, quant_config) @torch.no_grad() @@ -340,9 +342,11 @@ def forward( positions, input_metadata, ) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/dbrx.py b/python/sglang/srt/models/dbrx.py index 39ac4aefa72..b3a76b56ae2 100644 --- a/python/sglang/srt/models/dbrx.py +++ b/python/sglang/srt/models/dbrx.py @@ -45,6 +45,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -382,6 +383,7 @@ def __init__( padding_size=DEFAULT_VOCAB_PADDING_SIZE, ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -391,9 +393,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, 
input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek.py b/python/sglang/srt/models/deepseek.py index 59fd1ec7ed8..b939602c1ba 100644 --- a/python/sglang/srt/models/deepseek.py +++ b/python/sglang/srt/models/deepseek.py @@ -46,6 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -385,6 +386,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -394,9 +396,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 13dd477392e..15ecf4bb66b 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -45,6 +45,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -632,6 +633,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() def forward( self, @@ -640,9 +642,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma.py b/python/sglang/srt/models/gemma.py index ae3b1b1948c..5a6e5df37fe 100644 --- a/python/sglang/srt/models/gemma.py +++ b/python/sglang/srt/models/gemma.py @@ -37,6 +37,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -287,6 +288,7 @@ def __init__( self.quant_config = quant_config self.model = GemmaModel(config, quant_config=quant_config) self.logits_processor = 
LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -297,9 +299,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return (sample_output, logits_output) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/gemma2.py b/python/sglang/srt/models/gemma2.py index 3223424d79c..77ebd8564c6 100644 --- a/python/sglang/srt/models/gemma2.py +++ b/python/sglang/srt/models/gemma2.py @@ -37,6 +37,7 @@ from sglang.srt.layers.layernorm import GemmaRMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -346,6 +347,7 @@ def __init__( self.quant_config = quant_config self.model = Gemma2Model(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -356,9 +358,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def get_attention_sliding_window_size(self): return get_attention_sliding_window_size(self.config) diff --git a/python/sglang/srt/models/gpt_bigcode.py b/python/sglang/srt/models/gpt_bigcode.py index 94b7f6153cf..dc828f0142e 100644 --- a/python/sglang/srt/models/gpt_bigcode.py +++ b/python/sglang/srt/models/gpt_bigcode.py @@ -35,6 +35,7 @@ from sglang.srt.layers.activation import get_act_fn from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -261,6 +262,7 @@ def __init__( if lora_config: self.unpadded_vocab_size += lora_config.lora_extra_vocab_size self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -270,9 +272,11 @@ def forward( input_metadata: InputMetadata, ) -> torch.Tensor: hidden_states = self.transformer(input_ids, positions, input_metadata) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) diff --git a/python/sglang/srt/models/grok.py b/python/sglang/srt/models/grok.py index daf6f25da13..3c2a2c65eae 100644 --- a/python/sglang/srt/models/grok.py +++ b/python/sglang/srt/models/grok.py @@ -46,6 +46,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import 
RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -297,6 +298,7 @@ def __init__( self.model = Grok1Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() # Monkey patch _prepare_weights to load pre-sharded weights setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights) @@ -313,9 +315,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/internlm2.py b/python/sglang/srt/models/internlm2.py index f2947e991b5..c0e4d19e128 100644 --- a/python/sglang/srt/models/internlm2.py +++ b/python/sglang/srt/models/internlm2.py @@ -40,6 +40,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -262,6 +263,7 @@ def __init__( self.model = InternLM2Model(config, quant_config) self.output = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -272,9 +274,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.output.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/llama2.py b/python/sglang/srt/models/llama2.py index fe75916a43b..22751d9b674 100644 --- a/python/sglang/srt/models/llama2.py +++ b/python/sglang/srt/models/llama2.py @@ -39,8 +39,9 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor +from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -302,6 +303,7 @@ def __init__( self.model = LlamaModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -310,11 +312,13 @@ def forward( positions: torch.Tensor, input_metadata: InputMetadata, input_embeds: torch.Tensor = None, - ) -> LogitProcessorOutput: + ) -> LogitsProcessorOutput: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( 
+ logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def get_module_name(self, name): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/llama_classification.py b/python/sglang/srt/models/llama_classification.py index c5effbfc9c6..03ab5e802cf 100644 --- a/python/sglang/srt/models/llama_classification.py +++ b/python/sglang/srt/models/llama_classification.py @@ -24,7 +24,7 @@ from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from sglang.srt.layers.logits_processor import LogitProcessorOutput +from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.models.llama2 import LlamaModel @@ -65,7 +65,7 @@ def forward( (input_metadata.batch_size, self.config.classification_out_size) ).to(input_ids.device) - return LogitProcessorOutput( + return LogitsProcessorOutput( next_token_logits=scores, next_token_logprobs=scores, normalized_prompt_logprobs=scores, diff --git a/python/sglang/srt/models/minicpm.py b/python/sglang/srt/models/minicpm.py index 49ff1926f39..0028ae67a8c 100644 --- a/python/sglang/srt/models/minicpm.py +++ b/python/sglang/srt/models/minicpm.py @@ -39,6 +39,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -297,6 +298,7 @@ def __init__( self.scale_width = self.config.hidden_size / self.config.dim_model_base self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -314,9 +316,11 @@ def forward( lm_head_weight = self.model.embed_tokens.weight else: lm_head_weight = self.lm_head.weight - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, lm_head_weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral.py b/python/sglang/srt/models/mixtral.py index d11f6c95198..ca38cb03bae 100644 --- a/python/sglang/srt/models/mixtral.py +++ b/python/sglang/srt/models/mixtral.py @@ -41,6 +41,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -299,6 +300,7 @@ def __init__( self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() def forward( self, @@ -308,9 +310,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + 
sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/mixtral_quant.py b/python/sglang/srt/models/mixtral_quant.py index b02e925c5a0..97ac09ee629 100644 --- a/python/sglang/srt/models/mixtral_quant.py +++ b/python/sglang/srt/models/mixtral_quant.py @@ -45,6 +45,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -333,6 +334,7 @@ def __init__( self.model = MixtralModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -343,9 +345,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen.py b/python/sglang/srt/models/qwen.py index 93dae9585c3..4958a812985 100644 --- a/python/sglang/srt/models/qwen.py +++ b/python/sglang/srt/models/qwen.py @@ -39,6 +39,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -251,6 +252,7 @@ def __init__( vocab_size = ((config.vocab_size + 63) // 64) * 64 self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -260,10 +262,11 @@ def forward( input_metadata: InputMetadata, ): hidden_states = self.transformer(input_ids, positions, input_metadata) - next_tokens = self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - return next_tokens + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index a0c54f69105..6bb5c0b9066 100644 --- a/python/sglang/srt/models/qwen2.py +++ b/python/sglang/srt/models/qwen2.py @@ -38,8 +38,9 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor -from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType +from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata Qwen2Config = None @@ -276,6 +277,7 @@ def 
__init__( self.model = Qwen2Model(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) @torch.no_grad() @@ -289,9 +291,11 @@ def forward( ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) if not get_embedding: - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output else: return self.pooler(hidden_states, input_metadata) diff --git a/python/sglang/srt/models/qwen2_moe.py b/python/sglang/srt/models/qwen2_moe.py index d5c79a40f0e..67b5a6ce663 100644 --- a/python/sglang/srt/models/qwen2_moe.py +++ b/python/sglang/srt/models/qwen2_moe.py @@ -35,10 +35,8 @@ ReplicatedLinear, RowParallelLinear, ) -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope -from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, @@ -49,6 +47,7 @@ from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -366,6 +365,7 @@ def __init__( config.vocab_size, config.hidden_size, quant_config=quant_config ) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -376,20 +376,11 @@ def forward( input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) - - def compute_logits( - self, - input_ids: torch.Tensor, - hidden_states: torch.Tensor, - input_metadata: InputMetadata, - ) -> torch.Tensor: - logits = self.logits_processor( - input_ids, hidden_states, self.lm_head.weight, input_metadata - ) - return logits + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/models/stablelm.py b/python/sglang/srt/models/stablelm.py index 9e10f12f2a2..a3102baabd4 100644 --- a/python/sglang/srt/models/stablelm.py +++ b/python/sglang/srt/models/stablelm.py @@ -40,6 +40,7 @@ from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.radix_attention import RadixAttention +from sglang.srt.layers.sampler import Sampler from sglang.srt.model_executor.forward_batch_info import InputMetadata @@ -249,6 +250,7 @@ def __init__( self.model = StableLMEpochModel(config, quant_config=quant_config) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.logits_processor = LogitsProcessor(config) + self.sampler = Sampler() @torch.no_grad() def forward( @@ -259,9 +261,11 @@ def forward( 
input_embeds: torch.Tensor = None, ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) - return self.logits_processor( + logits_output = self.logits_processor( input_ids, hidden_states, self.lm_head.weight, input_metadata ) + sample_output = self.sampler(logits_output, input_metadata.sampling_info) + return sample_output, logits_output def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ diff --git a/python/sglang/srt/sampling/sampling_batch_info.py b/python/sglang/srt/sampling/sampling_batch_info.py index bc70a9018ed..7843f4bd32d 100644 --- a/python/sglang/srt/sampling/sampling_batch_info.py +++ b/python/sglang/srt/sampling/sampling_batch_info.py @@ -21,10 +21,63 @@ class SamplingBatchInfo: top_ps: torch.Tensor = None top_ks: torch.Tensor = None min_ps: torch.Tensor = None - penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None + + # Dispatch in CUDA graph + need_min_p_sampling: bool = False + + # Bias Tensors logit_bias: torch.Tensor = None vocab_mask: torch.Tensor = None + # Penalizer + penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None + linear_penalties: torch.Tensor = None + scaling_penalties: torch.Tensor = None + + def has_bias(self): + return ( + self.logit_bias is not None + or self.vocab_mask is not None + or self.linear_penalties is not None + or self.scaling_penalties is not None + ) + + @classmethod + def dummy_one(cls, max_bs: int, vocab_size: int): + ret = cls(vocab_size=vocab_size) + ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda") + ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda") + ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda") + ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda") + return ret + + def __getitem__(self, key): + if isinstance(key, slice): + # NOTE: We do not use cuda graph when there is bias tensors + assert not self.has_bias() + return SamplingBatchInfo( + vocab_size=self.vocab_size, + temperatures=self.temperatures[key], + top_ps=self.top_ps[key], + top_ks=self.top_ks[key], + min_ps=self.min_ps[key], + need_min_p_sampling=self.need_min_p_sampling, + ) + else: + raise NotImplementedError + + def inplace_assign(self, bs: int, other: SamplingBatchInfo): + # NOTE: We do not use cuda graph when there is bias tensors + assert not self.has_bias() + + self.vocab_size = other.vocab_size + self.need_min_p_sampling = other.need_min_p_sampling + + self.temperatures[:bs] = other.temperatures + self.top_ps[:bs] = other.top_ps + self.top_ks[:bs] = other.top_ks + self.min_ps[:bs] = other.min_ps + @classmethod def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): device = "cuda" @@ -45,6 +98,7 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): ret.min_ps = torch.tensor( [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device ) + ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs) # Each penalizers will do nothing if they evaluate themselves as not required by looking at # the sampling_params of the requests (See {_is_required()} of each penalizers). 
So this @@ -72,6 +126,25 @@ def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): return ret + def prepare_penalties(self): + self.scaling_penalties = None + self.linear_penalties = None + + for penalizer in self.penalizer_orchestrator.penalizers.values(): + if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer): + if penalizer.is_prepared(): + self.scaling_penalties = penalizer.cumulated_repetition_penalties + else: + if penalizer.is_prepared(): + if self.linear_penalties is None: + bs = self.penalizer_orchestrator.batch.batch_size() + self.linear_penalties = torch.zeros( + (bs, self.vocab_size), + dtype=torch.float32, + device="cuda", + ) + self.linear_penalties = penalizer.apply(self.linear_penalties) + def update_regex_vocab_mask(self, batch: ScheduleBatch): bs, reqs = batch.batch_size(), batch.reqs device = "cuda" diff --git a/python/sglang/test/runners.py b/python/sglang/test/runners.py index e69d699a7d3..ac69ab875b9 100644 --- a/python/sglang/test/runners.py +++ b/python/sglang/test/runners.py @@ -180,7 +180,7 @@ def __init__( tp_size=tp_size, dtype=get_dtype_str(torch_dtype), port=port, - mem_fraction_static=0.7, + mem_fraction_static=0.69, trust_remote_code=False, is_embedding=not self.is_generation, )
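
A minimal usage sketch of the calling convention this patch moves to, mirroring the tp_worker.py decode path above; the names (model_runner, batch, ForwardMode) are the SGLang objects shown in the hunks and this snippet is illustrative only, not part of the patch:

    # ModelRunner.forward now returns a (SampleOutput, LogitsProcessorOutput) pair
    sample_output, logits_output = model_runner.forward(batch, ForwardMode.DECODE)
    # check_sample_results falls back to argmax ids for any row whose sampling failed
    next_token_ids = batch.check_sample_results(sample_output)
    # raw logits remain available on the logits output for logprob post-processing
    next_token_logits = logits_output.next_token_logits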