[BUG Report] jacobi_greedy_search_multilevel function bug #56

Open
yangbohust opened this issue Apr 8, 2024 · 1 comment

yangbohust commented Apr 8, 2024

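When all GUESS_SIZE elements of the `myguess` list match the first GUESS_SIZE elements of the `correct` list, every guessed token has been accepted. In that case the last element of `correct` (the token the model predicts after the final accepted guess) is also a correct token, so it should be added to the `hits` list. The current `for` loop never does this. If no mismatch is found, the loop ends with `gg == len(myguess) - 1` rather than `len(myguess)`, so a full match is counted the same as a mismatch at the last position. In addition, `hits` is only sized for GUESS_SIZE tokens.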

https://github.com/hao-ai-lab/LookaheadDecoding/blob/9d50de4a81d1b473bfce104ace18fbbbb6dc3255/lade/decoding.py#L1068C1-L1085C88
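The intended verification rule can be stated as a small standalone function. This is a minimal sketch, not code from the repository: `accept_tokens` is a hypothetical name, and token IDs are plain integers instead of tensor values. As in the snippet below, `correct` holds `first_guess` followed by the model's GUESS_SIZE predictions, so it is one element longer than `myguess`.

```python
from typing import List


def accept_tokens(myguess: List[int], correct: List[int]) -> List[int]:
    """Return the tokens accepted for one guess n-gram (hypothetical helper).

    correct[i] is the token the model verifies for position i, and
    myguess[i] is the guessed token for that same position.
    """
    assert len(correct) == len(myguess) + 1
    k = 0
    # Length of the longest matching prefix; k can reach len(myguess).
    while k < len(myguess) and myguess[k] == correct[k]:
        k += 1
    # k matched guesses plus the one extra token the model predicted after them.
    return correct[:k + 1]


# All three guesses match, so all four tokens in `correct` are accepted.
print(accept_tokens([5, 7, 9], [5, 7, 9, 11]))  # [5, 7, 9, 11]
# Mismatch at index 1: only correct[:2] is accepted.
print(accept_tokens([5, 8, 9], [5, 7, 9, 11]))  # [5, 7]
```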

Original code:

```python
hits = [first_guess] + [0] * (GUESS_SIZE - 1)
# multi-level window is filled
# match guess tokens
if guess_tokens is not None:
    guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = 0
        for gg in range(len(myguess)):
            if myguess[gg] != correct[gg]:
                break
        # Bug: if every guess matches, the loop finishes with
        # gg == len(myguess) - 1, so correct[-1] is never accepted.
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
# max_hit is the length of longest accepted sequence in verification branch
```
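The off-by-one can be reproduced without the model. The sketch below copies the two matching loops into hypothetical helpers (`match_len_for`, `match_len_while`) that operate on plain Python lists instead of tensors. The `for` version returns the same count for a full match and for a mismatch at the last position, while the `while` version from the modified code distinguishes them.

```python
from typing import List


def match_len_for(myguess: List[int], correct: List[int]) -> int:
    # Matching loop from the original code.
    gg = 0
    for gg in range(len(myguess)):
        if myguess[gg] != correct[gg]:
            break
    return gg


def match_len_while(myguess: List[int], correct: List[int]) -> int:
    # Matching loop from the modified code.
    gg = 0
    while gg < len(myguess):
        if myguess[gg] != correct[gg]:
            break
        gg += 1
    return gg


correct = [5, 7, 9, 11]  # first_guess followed by 3 model predictions
full_match = [5, 7, 9]   # every guess is right
last_wrong = [5, 7, 4]   # only the last guess is wrong

print(match_len_for(full_match, correct), match_len_for(last_wrong, correct))      # 2 2
print(match_len_while(full_match, correct), match_len_while(last_wrong, correct))  # 3 2
```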

Modified code:

```python
hits = [first_guess] + [0] * GUESS_SIZE  # room for up to GUESS_SIZE + 1 accepted tokens
# multi-level window is filled
# match guess tokens
if guess_tokens is not None:
    guess_results = torch.argmax(outputs.guess_logits, dim=-1)[0].tolist()
    for eg in range(len(guess_results) // GUESS_SIZE):
        egx = eg * GUESS_SIZE
        correct = [first_guess] + guess_results[egx:egx + GUESS_SIZE]
        myguess = guess_tokens[egx:egx + GUESS_SIZE]
        gg = 0
        while gg < len(myguess):
            if myguess[gg] != correct[gg]:
                break
            gg += 1
        # gg is now the number of matching guesses; it reaches
        # len(myguess) when every guess matches.
        if gg > max_hit:
            max_hit = gg
            max_hit_idx = eg
            hits[:max_hit + 1] = correct[:max_hit + 1]
# max_hit is the length of longest accepted sequence in verification branch
```
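To exercise the modified snippet end to end, the sketch below wraps it in a hypothetical function `verify_guesses`. It assumes `guess_results` is already the list produced by `torch.argmax(...)[0].tolist()` and that `max_hit` and `max_hit_idx` start at 0, since their initialization is not shown in the issue. The token values are made up for illustration.

```python
from typing import List, Optional, Tuple


def verify_guesses(first_guess: int,
                   guess_tokens: Optional[List[int]],
                   guess_results: List[int],
                   guess_size: int) -> Tuple[List[int], int, int]:
    """Hypothetical wrapper around the modified matching logic.

    Assumes max_hit and max_hit_idx start at 0 (not shown in the issue).
    """
    max_hit, max_hit_idx = 0, 0
    hits = [first_guess] + [0] * guess_size
    if guess_tokens is not None:
        for eg in range(len(guess_results) // guess_size):
            egx = eg * guess_size
            correct = [first_guess] + guess_results[egx:egx + guess_size]
            myguess = guess_tokens[egx:egx + guess_size]
            gg = 0
            while gg < len(myguess):
                if myguess[gg] != correct[gg]:
                    break
                gg += 1
            # Keep the candidate n-gram with the longest accepted prefix.
            if gg > max_hit:
                max_hit, max_hit_idx = gg, eg
                hits[:max_hit + 1] = correct[:max_hit + 1]
    return hits, max_hit, max_hit_idx


# Two candidate 3-grams; the second one is guessed perfectly.
guess_tokens = [5, 1, 1] + [5, 7, 9]
guess_results = [7, 2, 2] + [7, 9, 11]
print(verify_guesses(5, guess_tokens, guess_results, 3))  # ([5, 7, 9, 11], 3, 1)
```

With the original `for` loop and the shorter `hits` list, the same input yields `max_hit == 2` and `hits == [5, 7, 9]`, so token 11 is dropped even though the model confirmed it.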
@k-l-lambda

I also found this bug. It is significant: whenever every guess in an n-gram is accepted, one correct token is wasted. Fixing it will improve performance, and the experimental data in the paper could be updated to reflect this.

I have submitted a pull request.
