Backtracking in beam search #151
Comments
@shubhamagarwal92 thanks for pointing this out. I'll check their implementation and see what's different. Working on the test to make it more deterministic. Will test more beam sizes too.
Hi,
I have been studying the code over the past few days, and I think you could use torch.repeat_interleave, as follows:
@Mehrad0711 maybe you can try and integrate BS from allennlp; their implementation here
Hey @Mehrad0711 @shubhamagarwal92 sorry haven't gotten the time to work on this yet. You're welcome to submit a PR :)
Any update on this issue?
Compared to OpenNMT, why do we need this block, which handles the dropped sequences that reached EOS earlier? (Their beam search implementation has no equivalent.) They do follow a similar process, in that they do not let EOS have children here. However, their end condition triggers when EOS is at the top of the beam, and they reconstruct the hypothesis using the get_hyp function.
More specifically, can you explain in detail what we are doing here?
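For reference, the backpointer walk that a get_hyp-style function performs can be sketched roughly as follows. This is a simplified pure-Python illustration, not OpenNMT's actual code; the names `tokens` and `backptrs`, their list-of-lists layout, and the EOS id are assumptions made for the example.

```python
from typing import List

EOS = 2  # hypothetical EOS token id for this example


def get_hyp(tokens: List[List[int]], backptrs: List[List[int]],
            step: int, beam: int) -> List[int]:
    """Rebuild one hypothesis by following backpointers from (step, beam).

    tokens[t][k]:   token chosen at step t in beam slot k (assumed layout)
    backptrs[t][k]: slot at step t-1 that slot k extends (unused at t=0)
    """
    hyp = []
    for t in range(step, -1, -1):
        hyp.append(tokens[t][beam])
        beam = backptrs[t][beam]  # move to the parent slot
    return hyp[::-1]  # tokens were collected last-to-first


# Beam size 2, three steps; slot 0 emits EOS at the last step.
tokens = [[5, 7], [8, 9], [EOS, 6]]
backptrs = [[0, 0], [1, 0], [1, 0]]
best = get_hyp(tokens, backptrs, step=2, beam=0)
```

With an end condition of "EOS is at the top", a walk like this only needs to start from the slot that holds EOS, so no separate bookkeeping for earlier-ended sequences is required during the walk itself.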
I understand why we need to handle EOS sequences, since their information is stored in the backtracking variables. But why do we need to "replace the one with the lowest survived sequence score with the new ended sequences"? AFAIK, res_k_idx tracks which beam (counting from the end) we can overwrite (the two conditions specified in the comments). However, we are not replacing the contents of the beam that produced EOS, i.e.:
I understand that after this process all the beams remain static and we use index_select at each step to select the top beams.
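To make the replacement rule in question concrete, here is a minimal sketch of how ended sequences could overwrite result slots starting from the lowest-scored end. The slot formula `k - (eos_found % k) - 1` is one reading of "from the end" and is an assumption, as are the function names and the toy data; this is not the decoder's actual code.

```python
from typing import List, Tuple


def slot_for_ended(k: int, eos_found: int) -> int:
    """Slot to overwrite for the next ended sequence (assumed formula).

    Slots are sorted best-first, so index k-1 is the lowest-scoring one.
    """
    return k - (eos_found % k) - 1


def place_ended(survivors: List[Tuple[float, str]],
                ended: List[Tuple[float, str]]) -> List[Tuple[float, str]]:
    """Overwrite survivor slots, worst first, with ended sequences."""
    k = len(survivors)
    result = list(survivors)
    for n, hyp in enumerate(ended):
        result[slot_for_ended(k, n)] = hyp
    return result


# Three surviving beams (best first) and two sequences that ended earlier.
survivors = [(-1.0, "a b c"), (-1.5, "a b d"), (-2.0, "a e f")]
ended = [(-0.8, "a <eos>"), (-1.2, "a b <eos>")]
slots = place_ended(survivors, ended)
```

In this sketch the slots are no longer ordered by score after the overwrite, so a final re-sort is needed before picking the best hypothesis. The beam that originally produced EOS is left untouched, which is the behavior the question above is about.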
Also, the unit test for top_k_decoder is not deterministic: it fails when batch_size > 2, and sometimes when batch_size == 2.
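One way to make such failures reproducible while debugging is to seed the random inputs of the test. The sketch below uses Python's standard `random` module and hypothetical names (`make_scores` and its shape arguments); a PyTorch test would seed with `torch.manual_seed` instead. Seeding makes a failure repeatable but does not by itself fix whatever causes it.

```python
import random
from typing import List


def make_scores(seed: int, batch_size: int, beam_size: int,
                vocab_size: int) -> List[List[List[float]]]:
    """Generate pseudo-random scores that are identical for a given seed."""
    rng = random.Random(seed)  # local generator; global state is untouched
    return [[[rng.random() for _ in range(vocab_size)]
             for _ in range(beam_size)]
            for _ in range(batch_size)]


# Two calls with the same seed yield the same test input.
first = make_scores(seed=0, batch_size=3, beam_size=2, vocab_size=4)
second = make_scores(seed=0, batch_size=3, beam_size=2, vocab_size=4)
```

Looping over a fixed list of seeds and batch sizes would then pin down which inputs trigger the failure.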