-
Notifications
You must be signed in to change notification settings - Fork 974
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Feature] Add sampler custom logits processor #2396
[Feature] Add sampler custom logits processor #2396
Conversation
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Thanks for taking this. Can you add some end-to-end tests and examples? |
@merrymercy Thanks for taking a look on this. I am still trying to understand the appropriate layer for the user to register their If the above seems correct, I will try to get it reviewable within this week. |
Your proposal sounds good. We should be able to register this function with all interfaces. |
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
…m/hongpeng-guo/sglang into hpguo/add_sampler_logit_processor
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are a few unit test on performance and accuracy being a bit flaky. they seem not to be directly related to this PR, though.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are almost there! Some final comments.
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Signed-off-by: Hongpeng Guo <[email protected]>
Thanks a lot for shepherding this! Just handled the comments. PTAL |
@@ -76,15 +88,48 @@ def from_schedule_batch( | |||
[r.sampling_params.min_p for r in reqs], dtype=torch.float | |||
).to(device, non_blocking=True) | |||
|
|||
# Check if any request has custom logit processor | |||
has_custom_logit_processor = any(r.custom_logit_processor for r in reqs) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
we can even skip this for loop if custom logit process is not enabled by the server.
self, unfinished_indices: List[int], new_indices: torch.Tensor | ||
): | ||
"""Filter the custom logit processor and custom params""" | ||
if not self.custom_logit_processor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this seems not needed. If self.custom_logit_processor
is None, why will it go to this function?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Agreed, I can remove this or make it a simple assert.
@hongpeng-guo Thanks. It is merged. Some follow-up items:
|
Motivation
This PR tries to support custom logits processor registered by users, so users can easily implement their custom sampling methods without the need to change the sglang code.
Related Issue
#2291
Modifications
CustomLogitProcessor
as insglang/srt/sampling/custom_logit_processor.py
. This interface contains a (1) Callable function to process the logtis, and (2)from_str
andto_str
methods that any user defined subclass can be serialized and passed into the server usingrequest.post(url, json=data)
to the/generate
endpoint.custom_params
as a field ofSamplingParams
custom_logit_processor
as a field toGenerateReqInput
,TokenizedGenerateReqInput
,Req
, andSamplingBatchInfo
Examples
Frist define a dummy
AddLogitProcessor
endpoint usage example:
offline engine usage example
TODO
update the docs as suggested by @merrymercy
Checklist