Restore the ability to draw multiple samples with Open Source models #416

Closed
rlouf opened this issue Dec 8, 2023 · 5 comments · Fixed by #533
Comments

@rlouf
Member

rlouf commented Dec 8, 2023

This was removed in #366 to simplify the PR, and should be added back. This will require care with shapes: an extra dimension will need to be added to account for the sample shape. The mechanism implemented there can be re-used when implementing beam search (#258).

@lapp0
Contributor

lapp0 commented Jan 3, 2024

I'm looking to implement beam search.

I'm wondering whether we could simply use the sequence tokens as the ID of the sequence?

8b1ff9a#diff-f65ffb5f52b2e358c713ccb8f32a700769426c6c8b655f689e3cdccae07d22ac

For 1,000-token sequences, I can generate 25,000 keys per second on my machine, so it shouldn't be a substantial bottleneck.
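
A rough sketch of what I have in mind (purely illustrative, not the actual outlines API; the names SequenceKey and sequence_cache are made up):

from typing import Dict, Tuple

# A sequence is identified by the tuple of its token ids; tuples are hashable,
# so they can be used directly as dictionary keys.
SequenceKey = Tuple[int, ...]

# e.g. a per-sequence cache keyed by the tokens generated so far
sequence_cache: Dict[SequenceKey, int] = {}

token_ids = [2, 415, 1342, 9768]
sequence_cache[tuple(token_ids)] = 7  # store some per-sequence state under that key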

@rlouf
Member Author

rlouf commented Jan 3, 2024

My understanding is that vLLM needs sequence ids because they're doing continuous batching, and we wouldn't need to assign an id to sequences here.

I'm still hesitating between using one big tensor of shape n_samples x n_batch x n_token_ids, like we did before the refactor, and breaking the sequences down like this:

from dataclasses import dataclass, field
from typing import List


@dataclass
class Sequence:
    prompt_token_ids: List[int]
    generated_token_ids: List[int] = field(default_factory=list)
    logprob: float = 0.0

    @property
    def token_ids(self):
        return self.prompt_token_ids + self.generated_token_ids

    def add_token_id(self, token_id, token_logprob):
        self.logprob += token_logprob
        self.generated_token_ids.append(token_id)


@dataclass
class Generation:
    sequences: List[Sequence]

For beam search, a Sequence would be a beam and a Generation would correspond to an input prompt. We have plans for better KV cache management in the future, and this would definitely simplify things.
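
For illustration, one beam search step with these classes could look roughly like this (just a sketch with placeholder token scores, not actual outlines code):

# One Generation per input prompt; each Sequence is a live beam.
prompt = [2, 415, 1342]
generation = Generation(
    sequences=[Sequence(prompt_token_ids=list(prompt)) for _ in range(3)]
)

# At each step, extend every beam with candidate tokens and keep the best beams.
candidates = []
for beam in generation.sequences:
    for token_id, token_logprob in [(10, -0.1), (11, -1.2)]:  # placeholder model output
        new_beam = Sequence(
            prompt_token_ids=beam.prompt_token_ids,
            generated_token_ids=list(beam.generated_token_ids),
            logprob=beam.logprob,
        )
        new_beam.add_token_id(token_id, token_logprob)
        candidates.append(new_beam)

# Keep the 3 highest-scoring beams for the next step.
generation.sequences = sorted(candidates, key=lambda s: s.logprob, reverse=True)[:3]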

This is the way vLLM does it: they create new tensors at each step. We would need to measure the overhead of creating new tensors at each step before moving forward.

What do you think?

@rlouf
Member Author

rlouf commented Jan 3, 2024

Would you mind copy/pasting your comment into a discussion and I'll answer there? Just so we stay on topic here and your comment is easier to find for future readers.

@lapp0
Contributor

lapp0 commented Jan 3, 2024

#501

@rlouf
Member Author

rlouf commented Jan 11, 2024

Here are the changes that need to be implemented in order to restore the ability to generate several samples for each sequence:

  1. Update sequence_generator so it expands the prompt ids from (n_batches, n_tokens) to (n_samples, n_batches, n_tokens) by duplicating the prompts. A single prompt with 10 tokens would lead to token_ids with shape (1, 1, 10), a batch of 3 prompts to shape (1, 3, 10) (with padding), and 7 samples for a batch of 3 prompts to shape (7, 3, 10). We keep singleton dimensions to simplify the code. Same with fsm_states. (See the sketch after this list.)
  2. Reshape the token_ids array to (n_batches * n_samples, n_tokens) before calling the token generator
  3. Update update_token_ids, expand_attention_masks and update_fsm_states
  4. Decode the result.
  5. Reshape the decoded sequences to (n_samples, n_batches, n_tokens) and remove singleton dimensions. Note that when we return a Sequence instance instead of just text in the future, we will only remove singleton dimensions when printing or extracting the text, but will keep them for the token_ids and logprobs.
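
A rough sketch of the shape flow above (illustrative only, using torch with toy values; the real code also has to carry attention masks and fsm_states through the same expansion):

import torch

n_samples, n_batches, n_tokens = 7, 3, 10

# Step 1: expand the prompt ids from (n_batches, n_tokens) to
# (n_samples, n_batches, n_tokens) by duplicating the prompts.
prompt_token_ids = torch.randint(0, 50_000, (n_batches, n_tokens))
token_ids = prompt_token_ids.unsqueeze(0).expand(n_samples, -1, -1)

# Step 2: flatten to (n_samples * n_batches, n_tokens) before calling the
# token generator (reshape materializes the expanded view here).
flat_token_ids = token_ids.reshape(n_samples * n_batches, n_tokens)

# ... the token generator appends newly sampled tokens to flat_token_ids ...

# Step 5: reshape back to (n_samples, n_batches, n_tokens) and drop singleton
# dimensions, e.g. (1, 1, n_tokens) -> (n_tokens,) for a single prompt with a
# single sample.
token_ids = flat_token_ids.reshape(n_samples, n_batches, -1)
final = token_ids.squeeze()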

We will need to add tests for init_generator_state, sequence_generator and the sampling algorithms.
