Using "beam search" strategy while generating the responses #2534

Closed
SachinVashisth opened this issue Dec 31, 2024 · 1 comment
Labels: 🙋 help from community wanted (open invitation for community members to contribute), 🏋 PPO (related to PPO)

Comments

@SachinVashisth

Hi,

I am using flan-t5-xl to generate outputs.
When I call ppo_trainer.generate(....), I get the desired output, but as far as I can tell it returns only a single sequence per input (the top beam, i.e. the best output).
I want several candidates per input instead. Currently I use this custom generate function, which runs beam search with 5 beams and returns the top 4 sequences:

import torch

def select(list_data, model, tokenizer, strategy="beam"):
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    beams, num_seq = 5, 4  # num_return_sequences must not exceed num_beams
    batch = tokenizer(list_data, return_tensors="pt", padding=True).to(device)
    if strategy != "beam":
        # Without this check, `generated` would be undefined for other strategies.
        raise ValueError(f"Unsupported strategy: {strategy!r}")
    generated = model.generate(
        input_ids=batch["input_ids"],
        attention_mask=batch["attention_mask"],
        num_beams=beams, early_stopping=True, num_return_sequences=num_seq,
        max_new_tokens=60, min_length=5,
    )
    # The result is flat: num_seq consecutive sequences for each input in list_data.
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True, clean_up_tokenization_spaces=True)
    return [text.lower() for text in decoded]

Is it possible to generate output for multiple beams using the ppo_trainer.generate(....) function?
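
For context, `generate` with `num_return_sequences=k` returns one flat batch of `len(list_data) * k` sequences, with the k candidates for each prompt stored consecutively. A minimal sketch of regrouping the decoded strings per prompt follows; `group_by_prompt` is a hypothetical helper written for this example, not part of transformers or trl, and the strings are placeholders rather than real model output.

```python
from typing import List


def group_by_prompt(decoded: List[str], num_return_sequences: int) -> List[List[str]]:
    """Split a flat list of decoded outputs into one list per input prompt."""
    if num_return_sequences <= 0:
        raise ValueError("num_return_sequences must be positive")
    if len(decoded) % num_return_sequences != 0:
        raise ValueError("decoded length is not a multiple of num_return_sequences")
    return [
        decoded[i : i + num_return_sequences]
        for i in range(0, len(decoded), num_return_sequences)
    ]


# Placeholder strings standing in for the list returned by select(...)
# when called with two prompts and num_seq = 4.
flat = ["a1", "a2", "a3", "a4", "b1", "b2", "b3", "b4"]
grouped = group_by_prompt(flat, num_return_sequences=4)
```

Here `grouped[0]` holds the four candidates for the first prompt and `grouped[1]` those for the second.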

August-murr added the 🙋 help from community wanted and 🏋 PPO labels on Jan 1, 2025
@edbeeching
Collaborator

While we do not expose this functionality at the moment, if you fork trl you should be able to add your beam search options to the GenerationConfig here:

generation_config = GenerationConfig(

If you find this change beneficial, we would welcome a PR to expose the options. I will close the issue but feel free to reopen if needed.
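
As a rough sketch of what such a change could look like, the beam search options from the question map directly onto keyword arguments that `transformers.GenerationConfig` accepts. The helper names `beam_search_kwargs` and `build_generation_config` are hypothetical, and the default values are simply the ones used in the `select` function above; how these would be wired into the trainer is not shown here.

```python
def beam_search_kwargs(num_beams=5, num_return_sequences=4, max_new_tokens=60):
    """Collect beam search options as GenerationConfig keyword arguments."""
    # Beam search cannot return more sequences than it keeps beams.
    if num_return_sequences > num_beams:
        raise ValueError("num_return_sequences must be <= num_beams")
    return {
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "early_stopping": True,
        "do_sample": False,  # plain beam search rather than beam sampling
        "max_new_tokens": max_new_tokens,
    }


def build_generation_config(**overrides):
    """Build a GenerationConfig from the options above (hypothetical helper)."""
    # Imported lazily so beam_search_kwargs works without transformers installed.
    from transformers import GenerationConfig

    return GenerationConfig(**beam_search_kwargs(**overrides))


kwargs = beam_search_kwargs()
```

In a fork, these keyword arguments would go into the `GenerationConfig(` call referenced above. With `num_return_sequences` greater than 1 the generated batch is larger than the query batch, so any trainer code that pairs each query with a single response would presumably need adjusting as well; that is an assumption worth checking against the trl source.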
