Include multiple annotators for WildBench (#3283)
liamjxu authored Jan 23, 2025
1 parent 2c14291 commit 80432dc
Showing 2 changed files with 52 additions and 24 deletions.
71 changes: 48 additions & 23 deletions src/helm/benchmark/annotation/wildbench_annotator.py
@@ -1,9 +1,11 @@
 import re
 from typing import Any
 from importlib.resources import files
+from typing import Dict
 
 from helm.benchmark.adaptation.request_state import RequestState
 from helm.benchmark.annotation.annotator import Annotator
+from helm.benchmark.annotation.model_as_judge import _AnnotatorModelInfo
 from helm.clients.auto_client import AutoClient
 from helm.common.request import Request
 
@@ -38,28 +40,51 @@ def annotate(self, request_state: RequestState) -> Any:
.replace("{$model_output}", model_output_text)
.replace("{$checklist}", "\n".join(request_state.instance.extra_data["checklist"]))
)
annotator_request = Request(
model="openai/gpt-4o-2024-05-13",
model_deployment="openai/gpt-4o-2024-05-13",
prompt=annotator_prompt,
temperature=0.0,
max_tokens=2000,
)
annotator_response = self._auto_client.make_request(annotator_request)
if not annotator_response.success:
raise Exception(f"Annotation request failed: {annotator_response.error}")
assert len(annotator_response.completions) == 1
annotator_response_text = annotator_response.completions[0].text
annotator_response_parts = self._pattern.search(annotator_response_text)
if not annotator_response_parts:
raise ValueError(f"Malformed annotator response: {annotator_response_text}")

strengths = annotator_response_parts[1].strip()
weaknesses = annotator_response_parts[2].strip()
score_text = annotator_response_parts[3].strip().strip('"')
try:
score = float(score_text)
except ValueError:
raise ValueError(f"Malformed score '{score_text}' in annotator response: {annotator_response_text}")
SHORT_NAME_TO_MODEL_INFO: Dict[str, _AnnotatorModelInfo] = {
"gpt": _AnnotatorModelInfo(
model_name="openai/gpt-4o-2024-05-13", model_deployment="openai/gpt-4o-2024-05-13"
),
"llama": _AnnotatorModelInfo(
model_name="meta/llama-3.1-405b-instruct-turbo",
model_deployment="together/llama-3.1-405b-instruct-turbo",
),
"claude": _AnnotatorModelInfo(
model_name="anthropic/claude-3-5-sonnet-20241022",
model_deployment="anthropic/claude-3-5-sonnet-20241022",
),
}
all_strengths = []
all_weaknesses = []
all_scores = []
for annotator_model in SHORT_NAME_TO_MODEL_INFO:
annotator_model_info = SHORT_NAME_TO_MODEL_INFO[annotator_model]
annotator_request = Request(
model=annotator_model_info.model_name,
model_deployment=annotator_model_info.model_deployment,
prompt=annotator_prompt,
temperature=0.0,
max_tokens=2000,
)
annotator_response = self._auto_client.make_request(annotator_request)
if not annotator_response.success:
continue # skip this annotator if the request failed
assert len(annotator_response.completions) == 1
annotator_response_text = annotator_response.completions[0].text
annotator_response_parts = self._pattern.search(annotator_response_text)
if not annotator_response_parts:
continue # skip this annotator if the response is malformed

strengths = annotator_response_parts[1].strip()
weaknesses = annotator_response_parts[2].strip()
score_text = annotator_response_parts[3].strip().strip('"')
try:
score = float(score_text)
except ValueError:
continue # skip this annotator if the score is not a number

all_strengths.append(strengths)
all_weaknesses.append(weaknesses)
all_scores.append(score)

return {"strengths": strengths, "weaknesses": weaknesses, "score": score}
return {"strengths": all_strengths, "weaknesses": all_weaknesses, "score": all_scores}
5 changes: 4 additions & 1 deletion src/helm/benchmark/metrics/wildbench_metrics.py
@@ -19,7 +19,10 @@ def evaluate_generation(
         eval_cache_path: str,
     ) -> List[Stat]:
         assert request_state.annotations
-        score = request_state.annotations["wildbench"]["score"]
+        all_scores = request_state.annotations["wildbench"]["score"]
+        if len(all_scores) == 0:
+            raise ValueError("Could not compute WB Score because all annotators failed.")
+        score = sum(all_scores) / len(all_scores)
         score_rescaled = (score - 1) / 9
         return [
             Stat(MetricName("wildbench_score")).add(score),
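Since the annotation now carries a list of per-judge scores, the metric averages them before applying WildBench's rescaling; judging from the `(score - 1) / 9` formula, the judges score on a 1-10 scale and the rescaled value lands in [0, 1]. A small worked sketch of the aggregation, using made-up scores (7, 8, 9 are illustrative values only, not real results):

```python
from typing import List


def wildbench_score(all_scores: List[float]) -> float:
    """Mean of the per-annotator scores, as computed in wildbench_metrics.py."""
    if len(all_scores) == 0:
        raise ValueError("Could not compute WB Score because all annotators failed.")
    return sum(all_scores) / len(all_scores)


def rescale(score: float) -> float:
    """Map a score on the presumed 1-10 judge scale onto [0, 1]."""
    return (score - 1) / 9


# Illustrative values only: three judges returning 7, 8, and 9.
scores = [7.0, 8.0, 9.0]
mean_score = wildbench_score(scores)  # 8.0
rescaled = rescale(mean_score)        # (8.0 - 1) / 9 ≈ 0.778
print(mean_score, rescaled)
```

If every judge failed, `all_scores` is empty and the metric raises instead of reporting a misleading zero.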
