feat: introduce results attribute on MMLU evaluator
In order to test the validity of our MMLU results or to inspect prior runs, we need
to be able to access the full output of the lm_eval.evaluator.simple_evaluate API.
This commit provides that ability by adding a results attribute to the MMLUEvaluator
class and storing the evaluation output there.

Signed-off-by: Oleg S <[email protected]>
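
For reference, a consumer of this change could exercise the new attribute roughly as follows. This is a minimal sketch: the model path and task name are placeholders, and the result keys shown are the ones used by the updated test script below.

from instructlab.eval.mmlu import MMLUEvaluator

# Sketch only: model path and task are illustrative placeholders.
mmlu = MMLUEvaluator(
    model_path="instructlab/granite-7b-lab",
    tasks=["mmlu_anatomy"],
)

# extra_args is forwarded to lm_eval.evaluator.simple_evaluate;
# log_samples asks lm_eval to retain per-sample records.
overall_score, individual_scores = mmlu.run(extra_args={"log_samples": True})

# New in this commit: the full simple_evaluate output is kept on the evaluator.
assert mmlu.results is not None
per_task_metrics = mmlu.results["results"]  # same data the scores above come from
samples = mmlu.results["samples"]           # present because log_samples=True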
RobotSail committed Dec 13, 2024
1 parent fd78adf commit ad12276
Showing 3 changed files with 95 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -1,6 +1,7 @@
## 0.4.2

* Adds the ability to provide a custom system prompt to the MMLU-based evaluators. When a system prompt is provided, lm_eval applies the chat template under the hood; otherwise the model is passed a barebones prompt.
* Adds an `extra_args` parameter to the `.run` method of all MMLU-based evaluators. This lets consumers pass any additional arguments they need directly through to the `lm_eval.evaluator.simple_evaluate` function.
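
A minimal sketch of how these two additions combine; the model path, task, and prompt below are placeholders:

from instructlab.eval.mmlu import MMLUEvaluator

evaluator = MMLUEvaluator(
    model_path="path/to/model",      # placeholder
    tasks=["mmlu_astronomy"],        # placeholder task
    system_prompt="You are a helpful AI assistant.",  # presence of a prompt triggers the chat template
)
overall_score, individual_scores = evaluator.run(
    extra_args={"write_out": True}   # forwarded as-is to lm_eval.evaluator.simple_evaluate
)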

## 0.4

53 changes: 52 additions & 1 deletion scripts/test_mmlu.py
@@ -1,9 +1,41 @@
# Standard
from typing import Dict, List, Tuple, TypedDict

# First Party
from instructlab.eval.mmlu import MMLUEvaluator

SYSTEM_PROMPT = """I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."""


class MMLUSample(TypedDict):
    """
    Example of a single sample returned from lm_eval when running MMLU.
    This is not a comprehensive type, just the subset of fields we care about for this test.
    """

    # Arguments is the list of (prompt, answer) pairs passed to MMLU as few-shot samples.
    # They will not be present with few_shot=0
    arguments: List[Tuple[str, str]]


def all_samples_contain_system_prompt(
    samples: Dict[str, List[MMLUSample]], prompt: str
) -> bool:
    """
    Given a mapping of evaluation --> list of results, validates that all few-shot examples
    included the system prompt
    """
    for topic, samples_set in samples.items():
        for sample in samples_set:
            for mmlu_prompt, _ in sample["arguments"]:
                if prompt not in mmlu_prompt:
                    # we are looking for the exact system prompt, so no need to normalize to lowercase
                    print(f"found a sample in the '{topic}' MMLU topic set missing the system prompt")
                    return False

    return True


def test_minimal_mmlu():
    print("===> Executing 'test_minimal_mmlu'...")
    try:
@@ -14,9 +46,28 @@ def test_minimal_mmlu():
            tasks=tasks,
            system_prompt=SYSTEM_PROMPT,
        )
        overall_score, individual_scores = mmlu.run()
        overall_score, individual_scores = mmlu.run(
            extra_args={"log_samples": True, "write_out": True}
        )
        samples = mmlu.results["samples"]

        print(overall_score)
        print(individual_scores)

        # we need n-shots > 1 to be able to validate the inclusion of the system prompt
        eligible_samples = {
            topic: samples[topic]
            for topic, shot in mmlu.results["n-shot"].items()
            if shot > 1
        }
        if eligible_samples:
            if not all_samples_contain_system_prompt(eligible_samples, SYSTEM_PROMPT):
                return False
        else:
            print(
                "MMLU was run in zero-shot mode, cannot confirm that system prompt was included, skipping check..."
            )

    except Exception as exc:
        print(f"'test_minimal_mmlu' failed: {exc}")
        return False
60 changes: 42 additions & 18 deletions src/instructlab/eval/mmlu.py
@@ -7,12 +7,12 @@
"""

# Standard
from typing import Optional, Union
from typing import Any, Dict, Optional, Union
import os

# Third Party
from lm_eval.evaluator import simple_evaluate # type: ignore
from lm_eval.tasks import TaskManager # type: ignore
from lm_eval.evaluator import simple_evaluate
from lm_eval.tasks import TaskManager
import torch

# First Party
@@ -103,6 +103,7 @@ class AbstractMMLUEvaluator(Evaluator):
        batch_size      batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times.
        device          PyTorch device (e.g. "cpu" or "cuda:0") for running models
        system_prompt   system prompt to be used when applying the chat template
        results         full output from the `lm_eval.evaluator.simple_evaluate` function after MMLU has run.
    """

    def __init__(
@@ -124,18 +125,33 @@ def __init__(
        self.few_shots = few_shots
        self.batch_size = batch_size
        self.device = device
        self._results = None

    def run(self, server_url: str | None = None) -> tuple:
    @property
    def results(self) -> Dict[str, Any] | None:
        """
        Returns the results of the last MMLU evaluation, if one has taken place.
        Returns:
            Dict[str, Any] | None: The output from `lm_eval.evaluator.simple_evaluate`
        """
        return self._results

    def run(
        self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None
    ) -> tuple:
        """
        Runs evaluation
        Attributes
            server_url    Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
            extra_args    Dictionary containing any extra arguments to be passed into the `lm_eval.evaluator.simple_evaluate` function.
        Returns:
            overall_score      Average score for the task group
            individual_scores  Individual scores for each task in the task group
        """
        extra_args = {} if not extra_args else extra_args
        logger.debug(locals())

        # TODO: make this a parameter for class?
@@ -156,7 +172,10 @@ def run(self, server_url: str | None = None) -> tuple:

        return overall_score, individual_scores

    def _run_mmlu(self, server_url: str | None = None) -> dict:
    def _run_mmlu(
        self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None
    ) -> dict:
        extra_args = {} if not extra_args else extra_args
        if server_url is not None:
            # Requires lm_eval >= 0.4.4
            model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface"
@@ -172,19 +191,24 @@ def _run_mmlu(
                raise InvalidTasksDirError(self.tasks_dir)
            tm = TaskManager(verbosity="DEBUG", include_path=self.tasks_dir)
        should_apply_chat_template = self.system_prompt is not None
        mmlu_output = self._simple_evaluate_with_error_handling(
            model=model,
            model_args=model_args,
            tasks=self.tasks,
            num_fewshot=self.few_shots,
            batch_size=self.batch_size,
            device=self.device,
            task_manager=tm,
            system_instruction=self.system_prompt,
            apply_chat_template=should_apply_chat_template,
        )
        results = mmlu_output["results"]
        return results

        # configure the args here so users can override them as necessary
        simple_evaluate_kwargs = {
            "model": model,
            "model_args": model_args,
            "tasks": self.tasks,
            "num_fewshot": self.few_shots,
            "batch_size": self.batch_size,
            "device": self.device,
            "task_manager": tm,
            "system_instruction": self.system_prompt,
            "apply_chat_template": should_apply_chat_template,
        }
        simple_evaluate_kwargs.update(extra_args)

        results = self._simple_evaluate_with_error_handling(**simple_evaluate_kwargs)
        self._results = results
        return results["results"]

    # This method converts general errors from simple_evaluate
    # into a more user-understandable error