I found that the retrieval rate is really high for Self-RAG.
In active_pipeline.py, line 164:

```python
def judge_retrieve(self, input_prompts):
    all_pred_text = []
    all_pred_log_probs = []
    # For vllm, requesting too many logprobs can seriously affect speed;
    # 20 logprobs are enough for this calculation.
    preds = self.generator.generate(
        input_prompts, return_raw_output=True, logprobs=20, max_tokens=1, skip_special_tokens=False
    )
    for single_pred in preds:
        pred_text = single_pred.outputs[0].text
        pred_log_probs = single_pred.outputs[0].logprobs
        all_pred_text.append(pred_text)
        all_pred_log_probs.append(pred_log_probs)
    retrieval_flags = []
    for idx, single_pred in enumerate(preds):
        if self.threshold is not None:
            score_dict = {}
            for tok, tok_id in self.ret_tokens.items():
                if tok_id not in all_pred_log_probs[idx][0]:
                    score_dict[tok] = -100
                else:
                    prob = all_pred_log_probs[idx][0][tok_id].logprob
                    score_dict[tok] = np.exp(prob)
            do_retrieve = (
                score_dict["[Retrieval]"] / (score_dict["[Retrieval]"] + score_dict["[No Retrieval]"])
                > self.threshold
            )
        else:
            do_retrieve = "[Retrieval]" in all_pred_text[idx]
        retrieval_flags.append(do_retrieve)
```
```python
if tok_id not in all_pred_log_probs[idx][0]:
    score_dict[tok] = -100
else:
    prob = all_pred_log_probs[idx][0][tok_id].logprob
    score_dict[tok] = np.exp(prob)
```
Is this right? It should be score_dict[tok] = np.exp(-100).
I tried that, but then the retrieval rate seems too high: the retriever is triggered almost all the time, which makes deciding whether to retrieve or not seem pointless.
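Concretely, the arithmetic behind what I saw (a sketch with illustrative logprob values, not actual model output) looks like this:

```python
import numpy as np

threshold = 0.2

# Suppose "[Retrieval]" is in the top-20 logprobs but "[No Retrieval]" is not.
ret_score = np.exp(-0.5)     # exp of an illustrative logprob, ~0.61
no_ret_score = np.exp(-100)  # fallback for the missing token, ~3.7e-44

ratio = ret_score / (ret_score + no_ret_score)
print(ratio, ratio > threshold)  # ~1.0, True -> retrieval fires
```

With np.exp(-100), the missing token contributes essentially zero probability, so whenever "[No Retrieval]" falls outside the top 20 the ratio is ~1.0 and retrieval always triggers.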
Thanks for your question. We noticed this in previous experiments, and there are indeed some problems here. If the score is left as a raw -100, the computed do_retrieve ratio becomes negative, which is below self.threshold; we will fix this in the next version. However, if that were the whole story, many examples in the dataset should end up not being retrieved, which is inconsistent with what you observed.
So I suspect the real cause is the following. The logic here reads the logprobs of the next generated token. Due to vllm's limitations and for performance reasons, we only request the logprobs of the top 20 tokens. If neither the "[Retrieval]" nor the "[No Retrieval]" tag appears among those top 20, both scores fall back to -100 and the overall ratio becomes 0.5, which is greater than the 0.2 threshold, so the query gets retrieved. We originally requested the top 30,000 tokens, but that made generation very slow.
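To make the mixed scales concrete, here is a small sketch of the three cases (the logprob values are illustrative, not actual model output):

```python
import numpy as np

threshold = 0.2

def ratio(ret, no_ret):
    return ret / (ret + no_ret)

# Both tags missing from the top 20: both scores stay at the raw -100 fallback.
print(ratio(-100, -100))                  # 0.5 > 0.2 -> retrieve

# Only "[Retrieval]" in the top 20: exp(logprob) mixed with a raw -100.
print(ratio(np.exp(-0.5), -100))          # ~-0.006 < 0.2 -> never retrieve

# Both tags in the top 20: a proper probability ratio in (0, 1).
print(ratio(np.exp(-0.5), np.exp(-1.5)))  # ~0.73
```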
I think this is a problem with the method itself or with the trained model, which leads to inaccurate retrieval scores.
To be clear: I changed -100 to np.exp(-100) and found that the retriever is triggered for almost every request.
Also, I looked into the official GitHub repo for Self-RAG, and they didn't even use np.exp() when computing the probability.
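If the official code indeed compares raw logprobs rather than probabilities (as described above; I haven't verified the exact snippet), the two normalizations can even disagree about the decision. A sketch with illustrative values:

```python
import numpy as np

lp_ret, lp_no = -0.01, -4.0  # "[Retrieval]" is far more likely here

raw_ratio = lp_ret / (lp_ret + lp_no)                           # ~0.0025
prob_ratio = np.exp(lp_ret) / (np.exp(lp_ret) + np.exp(lp_no))  # ~0.98

# Against a 0.2 threshold, logprob space says "no retrieve",
# probability space says "retrieve".
print(raw_ratio, prob_ratio)
```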
So I am confused about how to get correct results, and I don't see how this method can reliably decide whether to retrieve or not.