About the Self-RAG retrieval rate #113

Open
lhr-30 opened this issue Dec 24, 2024 · 2 comments
Comments

@lhr-30

lhr-30 commented Dec 24, 2024

I found that the retrieval rate is very high for Self-RAG.

In active_pipeline.py, line 164:

    def judge_retrieve(self, input_prompts):
        all_pred_text = []
        all_pred_log_probs = []
        # For vllm, requesting too many logprobs can seriously affect speed;
        # 20 probs are enough for the calculation
        preds = self.generator.generate(
            input_prompts, return_raw_output=True, logprobs=20, max_tokens=1, skip_special_tokens=False
        )
        for single_pred in preds:
            pred_text = single_pred.outputs[0].text
            pred_log_probs = single_pred.outputs[0].logprobs
            all_pred_text.append(pred_text)
            all_pred_log_probs.append(pred_log_probs)

        retrieval_flags = []
        for idx, single_pred in enumerate(preds):
            if self.threshold is not None:
                score_dict = {}
                for tok, tok_id in self.ret_tokens.items():
                    if tok_id not in all_pred_log_probs[idx][0]:
                        score_dict[tok] = -100
                    else:
                        prob = all_pred_log_probs[idx][0][tok_id].logprob
                        score_dict[tok] = np.exp(prob)
                do_retrieve = (
                    score_dict["[Retrieval]"] / (score_dict["[Retrieval]"] + score_dict["[No Retrieval]"])
                    > self.threshold
                )
            else:
                do_retrieve = "[Retrieval]" in all_pred_text[idx]

            retrieval_flags.append(do_retrieve)

    if tok_id not in all_pred_log_probs[idx][0]:
        score_dict[tok] = -100
    else:
        prob = all_pred_log_probs[idx][0][tok_id].logprob
        score_dict[tok] = np.exp(prob)

Is this right? Shouldn't it be score_dict[tok] = np.exp(-100)?
I tried that, but then the retrieval rate became too high: the retriever was triggered for almost every request. Doesn't that make the decision of whether to retrieve or not pointless?
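The effect of the sentinel choice can be checked in isolation. The sketch below (a hypothetical helper mirroring the quoted snippet, not FlashRAG's actual API) scores the ratio for a case where "[No Retrieval]" falls outside the top-20 logprobs, under both the original -100 sentinel and the np.exp(-100) replacement:

```python
import numpy as np

def retrieval_ratio(logprobs, missing):
    """Hypothetical helper mirroring the snippet above.

    `logprobs` maps the special-token string to its logprob for the first
    generated token; tokens absent from the top-k list get the sentinel
    `missing` as their score directly.
    """
    score = {tok: np.exp(logprobs[tok]) if tok in logprobs else missing
             for tok in ("[Retrieval]", "[No Retrieval]")}
    return score["[Retrieval]"] / (score["[Retrieval]"] + score["[No Retrieval]"])

# "[No Retrieval]" missing, np.exp(-100) (~0) as the sentinel:
# ratio -> exp(-0.1) / (exp(-0.1) + ~0) = 1.0, so retrieval always fires.
print(retrieval_ratio({"[Retrieval]": -0.1}, np.exp(-100)))  # 1.0

# Same case with the raw -100 sentinel: the denominator goes negative,
# the ratio is roughly -0.009, below any positive threshold, so
# retrieval is suppressed instead.
print(retrieval_ratio({"[Retrieval]": -0.1}, -100.0))
```

So the two conventions push the decision in opposite directions whenever one of the two special tokens is cut from the top-20 list.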

@ignorejjj (Member)

Thanks for your question. We noticed this in previous experiments, and there are indeed some problems here. If the score is set to -100, the computed do_retrieve ratio can become negative, which is below self.threshold; we will fix this in the next version. However, if that were the whole story, a lot of examples in the dataset would not trigger retrieval, which is inconsistent with what you observed.

So I suspect the explanation is the following. The logic here accesses the logprobs of the first generated token. Due to vllm's limitations and for performance reasons, we only fetch the logprobs of the top 20 tokens. In that case, if the "[Retrieval]" and "[No Retrieval]" tags are both absent from the top 20, the overall score becomes 0.5, which is greater than the threshold of 0.2, so the query is retrieved. We actually requested the top 30000 tokens at first, but that made generation very slow.

I think this is a problem with this method or the trained model, which leads to inaccurate retrieval scores.
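The both-tags-missing case described above can be verified numerically (a standalone sketch of the arithmetic, not the FlashRAG code path): when neither special token survives the top-20 cut and both scores fall back to the same constant np.exp(-100), the ratio collapses to exactly 0.5, above the 0.2 threshold, so every such query triggers retrieval.

```python
import numpy as np

# Both special tokens missing from the top-20 logprobs: each score is the
# same tiny constant np.exp(-100), so the ratio x / (x + x) is exactly 0.5.
score = {tok: np.exp(-100) for tok in ("[Retrieval]", "[No Retrieval]")}
ratio = score["[Retrieval]"] / (score["[Retrieval]"] + score["[No Retrieval]"])
print(ratio)  # 0.5, which exceeds a 0.2 threshold, so retrieval fires
```

This matches the reported behavior: whenever the model's top-20 next-token distribution contains neither tag, the decision defaults to retrieving.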

@lhr-30

lhr-30 commented Dec 25, 2024

I mean that I changed -100 to np.exp(-100) and found that the retriever was triggered for almost every request.
Also, I looked into the official GitHub repo for Self-RAG; they don't even use np.exp() when computing the probability.

[screenshot of the official Self-RAG implementation]

So I am confused about how to get correct results, and I don't see how we can reliably decide whether to retrieve or not based on this method.
