diff --git a/optillm/bon.py b/optillm/bon.py index 8ee752a..7f38ed9 100644 --- a/optillm/bon.py +++ b/optillm/bon.py @@ -9,18 +9,30 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st {"role": "user", "content": initial_query}] completions = [] - - response = client.chat.completions.create( - model=model, - messages=messages, - max_tokens=4096, - n=n, - temperature=1 - ) - completions = [choice.message.content for choice in response.choices] + + try: + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=4096, + n=n, + temperature=1 + ) + completions = [choice.message.content for choice in response.choices] + bon_completion_tokens += response.usage.completion_tokens + except: + for _ in range(n): + response = client.chat.completions.create( + model=model, + messages=messages, + max_tokens=4096, + n=1, + temperature=1 + ) + completions.extend([choice.message.content for choice in response.choices]) + bon_completion_tokens += response.usage.completion_tokens logger.info(f"Generated {len(completions)} initial completions. Tokens used: {response.usage.completion_tokens}") - bon_completion_tokens += response.usage.completion_tokens - + # Rate the completions rating_messages = messages.copy() rating_messages.append({"role": "system", "content": "Rate the following responses on a scale from 0 to 10, where 0 is poor and 10 is excellent. Consider factors such as relevance, coherence, and helpfulness. Respond with only a number."})