[Feature] Add sampler custom logits processor #2396

Merged

Conversation

@hongpeng-guo (Contributor) commented Dec 8, 2024

Motivation

This PR adds support for user-registered custom logits processors, so users can implement their own sampling methods without modifying the sglang codebase.

Related Issue

#2291

Modifications

  1. Introduce a common abstract interface, CustomLogitProcessor, in sglang/srt/sampling/custom_logit_processor.py. The interface contains (1) a callable that processes the logits, and (2) from_str and to_str methods so that any user-defined subclass can be serialized and passed to the server's /generate endpoint via requests.post(url, json=data). A sketch of this interface is shown after this list.
  2. Add custom_params as a field of SamplingParams.
  3. Add custom_logit_processor as a field of GenerateReqInput, TokenizedGenerateReqInput, Req, and SamplingBatchInfo.
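
For reference, here is a minimal sketch of what the CustomLogitProcessor interface described in item 1 might look like. The use of dill for serialization and the exact payload format are assumptions for illustration, not necessarily what this PR implements.

import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List

import dill  # assumption: the serializer actually used by the PR may differ
import torch


class CustomLogitProcessor(ABC):
    """Abstract base class for user-defined logit processors (sketch)."""

    @abstractmethod
    def __call__(
        self, logits: torch.Tensor, custom_param_list: List[Dict[str, Any]]
    ) -> torch.Tensor:
        """Return modified logits; custom_param_list holds one dict per request."""
        ...

    def to_str(self) -> str:
        """Serialize the processor so it can be embedded in a JSON request body."""
        return json.dumps({"callable": dill.dumps(self).hex()})

    @classmethod
    def from_str(cls, s: str) -> "CustomLogitProcessor":
        """Reconstruct the processor from the string produced by to_str()."""
        return dill.loads(bytes.fromhex(json.loads(s)["callable"]))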

Examples

First, define a dummy AddLogitProcessor:

import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class AddLogitProcessor(CustomLogitProcessor):
    """Add a per-request constant (custom_params["arg1"]) to every logit."""

    def __call__(self, logits, custom_param_list):
        # Expect one custom_params dict per request in the batch.
        assert logits.shape[0] == len(custom_param_list)
        key = "arg1"

        # Gather the per-request offsets and move them to the logits device.
        merged_params = torch.tensor(
            [params[key] for params in custom_param_list], dtype=torch.float
        ).to(device=logits.device, non_blocking=True)
        # Reshape to [batch, 1] so the addition broadcasts over the vocab dimension.
        return logits + merged_params.unsqueeze(-1)

Endpoint usage example:

import requests

url = "http://localhost:30000/generate"
data = {
    "text": "What is the capital of France?",
    "sampling_params": {
        "custom_params": {
            "arg1": 5.0,
        },
    },
    # The processor is serialized to a string so it can travel in the JSON body.
    "custom_logit_processor": AddLogitProcessor().to_str(),
}

response = requests.post(url, json=data)
print(response.json())

Offline engine usage example:

import sglang as sgl

llm = sgl.Engine(model_path="meta-llama/Meta-Llama-3.1-8B-Instruct")
prompt = "The president of the United States is"
sampling_params = {"temperature": 0.8, "top_p": 0.95, "custom_params": {"arg1": 5.0}}

output = llm.generate(
    prompt, 
    sampling_params,
    custom_logit_processor=AddLogitProcessor().to_str()
)
print(output['text'])

TODO

Update the docs as suggested by @merrymercy.

Checklist

  • Format your code according to the Contributor Guide.
  • Add unit tests as outlined in the Contributor Guide.
  • Update documentation as needed, including docstrings or example tutorials.

@hongpeng-guo marked this pull request as a draft on December 8, 2024 at 09:39
@hongpeng-guo changed the title from "Add sampler logit processor" to "[WIP] Add sampler logit processor" on Dec 8, 2024
@merrymercy (Contributor)

Thanks for taking this. Can you add some end-to-end tests and examples?

@hongpeng-guo (Contributor, Author)

> Thanks for taking this. Can you add some end-to-end tests and examples?

@merrymercy Thanks for taking a look at this. I am still trying to figure out the appropriate layer at which users should register their customized_logit_processor. The goal is to enable the customized_logit_processor functionality without requiring changes to the internal sglang codebase. To achieve this, the function registration should happen at the API layer; the customized_logit_processor_fn and custom_params will then be passed from the program driver down to the internal modules, such as Sampler and SamplingParams.

If the above seems correct, I will try to get it reviewable within this week.

@merrymercy (Contributor) commented Dec 26, 2024

Your proposal sounds good. We should be able to register this function through all interfaces.
For example, it should work with both the native /generate API and the offline Engine API.

@merrymercy self-assigned this on Dec 26, 2024
@hongpeng-guo (Contributor, Author) left a comment

A few performance and accuracy unit tests are a bit flaky; they do not seem to be directly related to this PR, though.

@merrymercy (Contributor) left a comment

We are almost there! Some final comments.

python/sglang/srt/managers/schedule_batch.py (outdated comment, resolved)
@hongpeng-guo (Contributor, Author)

> We are almost there! Some final comments.

Thanks a lot for shepherding this! Just handled the comments. PTAL

@@ -76,15 +88,48 @@ def from_schedule_batch(
[r.sampling_params.min_p for r in reqs], dtype=torch.float
).to(device, non_blocking=True)

# Check if any request has custom logit processor
has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
Contributor comment:

We can even skip this for-loop if the custom logit processor is not enabled by the server.
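
To illustrate the suggestion, here is a sketch of the gated check; server_args and enable_custom_logit_processor are hypothetical names used only to show the idea, not necessarily what the PR ended up with.

# Sketch only: `server_args.enable_custom_logit_processor` is a hypothetical
# server-level flag; the real flag name and plumbing may differ.
if server_args.enable_custom_logit_processor:
    has_custom_logit_processor = any(r.custom_logit_processor for r in reqs)
else:
    # Feature disabled server-side: skip scanning the batch entirely.
    has_custom_logit_processor = False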

self, unfinished_indices: List[int], new_indices: torch.Tensor
):
"""Filter the custom logit processor and custom params"""
if not self.custom_logit_processor:
Contributor comment:

This seems unnecessary. If self.custom_logit_processor is None, why would execution reach this function at all?

Author reply:

Agreed, I can remove this or make it a simple assert.
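
For concreteness, a small sketch of the assert variant discussed here; the method name and its enclosing class are placeholders, since the diff above does not show them.

def filter_custom_logit_processor(self, unfinished_indices, new_indices):
    """Filter the custom logit processor and custom params (sketch)."""
    # Replace the early-return guard with an assertion: callers should only
    # reach this path when a processor is actually present.
    assert self.custom_logit_processor is not None
    # ... per-request filtering of processors and custom_params goes here ...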

@merrymercy merged commit e403d23 into sgl-project:main on Jan 19, 2025
16 checks passed
@merrymercy (Contributor)

@hongpeng-guo Thanks. It is merged.

Some follow-up items:
