Merge pull request #80 from wandb/weaveeval
W&B Weave based Evaluation
Showing 6 changed files with 308 additions and 0 deletions.
@@ -0,0 +1,41 @@
import os
os.environ["WANDB_ENTITY"] = "wandbot"

import wandb
import weave
import pandas as pd
from weave import Dataset

from wandbot.evaluation.config import EvalConfig

config = EvalConfig()

wandb_project = config.wandb_project
wandb_entity = config.wandb_entity

# Download the annotated evaluation data from the W&B artifact.
eval_artifact = wandb.Api().artifact(config.eval_artifact)
eval_artifact_dir = eval_artifact.download(root=config.eval_artifact_root)

df = pd.read_json(
    f"{eval_artifact_dir}/{config.eval_annotations_file}",
    lines=True,
    orient="records",
)
df.insert(0, "id", df.index)

# Keep only W&B-related queries whose reference answers were annotated as correct.
correct_df = df[
    (df["is_wandb_query"] == "YES") & (df["correctness"] == "correct")
]

data_rows = correct_df.to_dict("records")

weave.init(wandb_project)

# Create a dataset
dataset = Dataset(
    name="wandbot_eval_data",
    rows=data_rows,
)

# Publish the dataset
weave.publish(dataset)
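
Once published, the dataset can be read back by name from the Weave project, which is how the evaluation script in the next file consumes it. A minimal sketch, assuming the current Weave client API; the ":latest" alias and the printed fields are illustrative:

import weave

# Same entity/project the evaluation script reads from (see the weave:/// URI below).
weave.init("wandbot/wandbot-eval")

# Fetch the latest published version of the dataset and inspect a row.
dataset = weave.ref("wandbot_eval_data:latest").get()
rows = list(dataset.rows)
print(len(rows), "rows")
print(rows[0]["question"])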
@@ -0,0 +1,114 @@
import os
os.environ["WANDB_ENTITY"] = "wandbot"

import json
import httpx
import weave
import asyncio
import requests
from weave import Evaluation
from weave import Model
from llama_index.llms.openai import OpenAI

from wandbot.evaluation.config import EvalConfig
from wandbot.utils import get_logger

from wandbot.evaluation.eval.correctness import (
    CORRECTNESS_EVAL_TEMPLATE,
    WandbCorrectnessEvaluator,
)

logger = get_logger(__name__)
config = EvalConfig()

correctness_evaluator = WandbCorrectnessEvaluator(
    llm=OpenAI(config.eval_judge_model),
    eval_template=CORRECTNESS_EVAL_TEMPLATE,
)

wandb_project = config.wandb_project
wandb_entity = config.wandb_entity

weave.init(f"{wandb_entity}/{wandb_project}")


@weave.op()
async def get_answer(question: str, application: str = "api-eval") -> str:
    # Query the locally running wandbot chat API and return its JSON response as a string.
    url = "http://0.0.0.0:8000/chat/query"
    payload = {
        "question": question,
        "application": application,
        "language": "en",
    }
    async with httpx.AsyncClient(timeout=200.0) as client:
        response = await client.post(url, json=payload)
        response_json = response.json()
        return json.dumps(response_json)


@weave.op()
async def get_eval_record(
    question: str,
) -> dict:
    # Flatten the API response into the fields used for scoring and logging.
    response = await get_answer(question)
    response = json.loads(response)
    return {
        "system_prompt": response["system_prompt"],
        "generated_answer": response["answer"],
        "retrieved_contexts": response["source_documents"],
        "model": response["model"],
        "total_tokens": response["total_tokens"],
        "prompt_tokens": response["prompt_tokens"],
        "completion_tokens": response["completion_tokens"],
        "time_taken": response["time_taken"],
    }


class EvaluatorModel(Model):
    eval_judge_model: str = config.eval_judge_model

    @weave.op()
    async def predict(self, question: str) -> dict:
        # Model logic goes here
        prediction = await get_eval_record(question)
        return prediction


@weave.op()
async def get_answer_correctness(
    question: str,
    ground_truth: str,
    notes: str,
    model_output: dict,
) -> dict:
    result = await correctness_evaluator.aevaluate(
        query=question,
        response=model_output["generated_answer"],
        reference=ground_truth,
        contexts=model_output["retrieved_contexts"],
        reference_notes=notes,
    )
    return {
        "answer_correctness": result.dict()["passing"]
    }


dataset_ref = weave.ref(
    "weave:///wandbot/wandbot-eval/object/wandbot_eval_data:eCQQ0GjM077wi4ykTWYhLPRpuGIaXbMwUGEB7IyHlFU"
).get()
question_rows = dataset_ref.rows
question_rows = [
    {
        "question": row["question"],
        "ground_truth": row["answer"],
        "notes": row["notes"],
    }
    for row in question_rows
]
logger.info("Number of evaluation samples: %s", len(question_rows))

evaluation = Evaluation(
    dataset=question_rows, scorers=[get_answer_correctness]
)

if __name__ == "__main__":
    asyncio.run(evaluation.evaluate(EvaluatorModel()))
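
This script assumes the wandbot chat API is already serving on http://0.0.0.0:8000. Before launching the full Weave evaluation, a single question can be pushed through the same ops as a sanity check; a hedged sketch in which the question, reference answer, and notes are illustrative values only:

import asyncio


async def smoke_test() -> None:
    question = "How do I log a confusion matrix with wandb?"  # illustrative
    record = await get_eval_record(question)
    verdict = await get_answer_correctness(
        question=question,
        ground_truth="Use wandb.plot.confusion_matrix and log it with wandb.log.",  # illustrative
        notes="The answer should mention wandb.plot.confusion_matrix.",  # illustrative
        model_output=record,
    )
    print(verdict)  # e.g. {"answer_correctness": True}


# asyncio.run(smoke_test())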
@@ -0,0 +1,127 @@
import asyncio
from typing import Any, Optional, Sequence

import regex as re
from llama_index.core.evaluation import CorrectnessEvaluator, EvaluationResult

from wandbot.evaluation.eval.utils import (
    make_eval_template,
    safe_parse_eval_response,
)

import wandb
import weave

SYSTEM_TEMPLATE = """You are a Weights & Biases support expert tasked with evaluating the correctness of answers to questions asked by users to a technical support chatbot.

You are given the following information:
- a user query,
- the documentation used to generate the answer,
- a reference answer,
- the reason why the reference answer is correct, and
- a generated answer.

Your job is to judge the relevance and correctness of the generated answer.
- Consider whether the answer addresses all aspects of the question.
- The generated answer must provide only correct information according to the documentation.
- Compare the generated answer to the reference answer for completeness and correctness.
- Output a score and a decision that represent a holistic evaluation of the generated answer.
- You must return your response only in the format mentioned below. Do not return answers in any other format.

Follow these guidelines for scoring:
- Your score has to be between 1 and 3, where 1 is the worst and 3 is the best.
- If the generated answer is not correct in comparison to the reference, you should give a score of 1.
- If the generated answer is correct in comparison to the reference but contains mistakes, you should give a score of 2.
- If the generated answer is correct in comparison to the reference and completely answers the user's query, you should give a score of 3.

Output your final verdict by strictly following this JSON format:
{{
    "reason": <<Provide a brief explanation for your decision here>>,
    "score": <<Provide a score as per the above guidelines>>,
    "decision": <<Provide your final decision here, either 'correct', or 'incorrect'>>
}}

Example Response 1:
{{
    "reason": "The generated answer has the exact details as the reference answer and completely answers the user's query.",
    "score": 3,
    "decision": "correct"
}}

Example Response 2:
{{
    "reason": "The generated answer doesn't match the reference answer, and deviates from the documentation provided.",
    "score": 1,
    "decision": "incorrect"
}}

Example Response 3:
{{
    "reason": "The generated answer follows the same steps as the reference answer. However, it includes assumptions about methods that are not mentioned in the documentation.",
    "score": 2,
    "decision": "incorrect"
}}
"""


USER_TEMPLATE = """
## User Query
{query}

## Documentation
{context_str}

## Reference Answer
{reference_answer}

## Reference Correctness Reason
{reference_notes}

## Generated Answer
{generated_answer}
"""

CORRECTNESS_EVAL_TEMPLATE = make_eval_template(SYSTEM_TEMPLATE, USER_TEMPLATE)


class WandbCorrectnessEvaluator(CorrectnessEvaluator):
    @weave.op()
    async def aevaluate(
        self,
        query: Optional[str] = None,
        response: Optional[str] = None,
        contexts: Optional[Sequence[str]] = None,
        reference: Optional[str] = None,
        sleep_time_in_seconds: int = 0,
        **kwargs: Any,
    ) -> EvaluationResult:
        await asyncio.sleep(sleep_time_in_seconds)

        if query is None or response is None or reference is None:
            print(query, response, reference, flush=True)
            raise ValueError("query, response, and reference must be provided")

        # Ask the judge LLM to grade the generated answer against the reference.
        eval_response = await self._llm.apredict(
            prompt=self._eval_template,
            query=query,
            generated_answer=response,
            reference_answer=reference,
            context_str=re.sub(
                "\n+", "\n", "\n---\n".join(contexts) if contexts else ""
            ),
            reference_notes=kwargs.get("reference_notes", ""),
        )

        # Parse the judge's JSON verdict into (passing, reasoning, score).
        passing, reasoning, score = await safe_parse_eval_response(
            eval_response, "correct"
        )

        return EvaluationResult(
            query=query,
            response=response,
            passing=passing,
            score=score,
            feedback=reasoning,
        )
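
Outside the full Weave evaluation, the evaluator module above (imported in the evaluation script as wandbot.evaluation.eval.correctness) can also be exercised directly, mirroring how that script wires it up. A minimal sketch, assuming an OpenAI key is configured; the judge model name and the query/answer/context strings are illustrative:

import asyncio
from llama_index.llms.openai import OpenAI


async def main() -> None:
    evaluator = WandbCorrectnessEvaluator(
        llm=OpenAI("gpt-4-0125-preview"),  # illustrative judge model
        eval_template=CORRECTNESS_EVAL_TEMPLATE,
    )
    result = await evaluator.aevaluate(
        query="How do I resume a crashed run?",  # illustrative
        response="Pass the run id and resume='must' to wandb.init().",
        reference="Call wandb.init(id=run_id, resume='must') to resume the run.",
        contexts=["wandb.init accepts `id` and `resume` arguments ..."],
        reference_notes="The answer must mention resuming via wandb.init.",
    )
    print(result.passing, result.score, result.feedback)


# asyncio.run(main())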