Merge pull request #197 from RobotSail/fix-mmlu
Allows a system_prompt to be provided to the MMLU-based evaluators
mergify[bot] authored Dec 13, 2024
2 parents 4cf3e14 + ad12276 commit c086116
Showing 6 changed files with 140 additions and 28 deletions.
3 changes: 2 additions & 1 deletion .pylintrc
@@ -448,7 +448,8 @@ disable=raw-checker-failed,
pointless-statement,
wrong-import-order,
line-too-long,
dangerous-default-value
dangerous-default-value,
too-many-instance-attributes

# Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option
2 changes: 2 additions & 0 deletions .spellcheck-en-custom.txt
@@ -26,3 +26,5 @@ TODO
tox
venv
vllm
barebones
LM
5 changes: 5 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,8 @@
## 0.4.2

* Adds the ability to provide a custom system prompt to the MMLU-based evaluators. When a system prompt is provided, LM-eval applies the chat template under the hood; otherwise the model is passed a barebones prompt.
* Adds an `extra_args` parameter to the `.run` method of all MMLU-based evaluators, so consumers can pass any additional arguments they want directly through to the `lm_eval.evaluator.simple_evaluate` function.

## 0.4

* Added ability to specify a custom http client to MT-Bench
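
For reference, a minimal, hypothetical usage sketch of the options described in the changelog entries above (the model path, task list, and prompt text are only illustrative; running it requires the model weights and an lm_eval installation):

```python
# Hypothetical usage sketch -- not part of this commit's diff.
from instructlab.eval.mmlu import MMLUEvaluator

evaluator = MMLUEvaluator(
    model_path="instructlab/granite-7b-lab",       # illustrative model
    tasks=["mmlu_anatomy", "mmlu_astronomy"],
    system_prompt="You are a helpful assistant.",  # enables chat-template application
)

# extra_args is forwarded to lm_eval.evaluator.simple_evaluate; here we also
# ask lm_eval to log the individual samples so they can be inspected later.
overall_score, individual_scores = evaluator.run(
    extra_args={"log_samples": True, "write_out": True}
)

# After run(), the full simple_evaluate output is exposed via the new property.
raw_results = evaluator.results
print(overall_score, individual_scores)
```
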
61 changes: 59 additions & 2 deletions scripts/test_mmlu.py
@@ -1,16 +1,73 @@
# Standard
from typing import Dict, List, Tuple, TypedDict

# First Party
from instructlab.eval.mmlu import MMLUEvaluator

SYSTEM_PROMPT = """I am, Red Hat® Instruct Model based on Granite 7B, an AI language model developed by Red Hat and IBM Research, based on the Granite-7b-base language model. My primary function is to be a chat assistant."""


class MMLUSample(TypedDict):
"""
Example of a single sample returned from lm_eval when running MMLU.
This is not a comprehensive type, just the subset of fields we care about for this test.
"""

# Arguments is the list of (prompt, answer) pairs passed to MMLU as few-shot samples.
# They will not be present when few_shots=0.
arguments: List[Tuple[str, str]]


def all_samples_contain_system_prompt(
samples: Dict[str, List[MMLUSample]], prompt: str
) -> bool:
"""
Given a mapping of evaluation --> list of results, validates that all few-shot examples
included the system prompt
"""
for topic, samples_set in samples.items():
for sample in samples_set:
for mmlu_prompt, _ in sample["arguments"]:
if prompt not in mmlu_prompt:
# we are looking for the exact system prompt, so no need to normalize to lowercase
print(f"found a sample missing the system prompt in the '{topic}' MMLU topic set")
return False

return True


def test_minimal_mmlu():
print("===> Executing 'test_minimal_mmlu'...")
try:
model_path = "instructlab/granite-7b-lab"
tasks = ["mmlu_anatomy", "mmlu_astronomy"]
mmlu = MMLUEvaluator(model_path=model_path, tasks=tasks)
overall_score, individual_scores = mmlu.run()
mmlu = MMLUEvaluator(
model_path=model_path,
tasks=tasks,
system_prompt=SYSTEM_PROMPT,
)
overall_score, individual_scores = mmlu.run(
extra_args={"log_samples": True, "write_out": True}
)
samples = mmlu.results["samples"]

print(overall_score)
print(individual_scores)

# we need n-shots > 1 to be able to validate the inclusion of the system prompt
eligible_samples = {
topic: samples[topic]
for topic, shot in mmlu.results["n-shot"].items()
if shot > 1
}
if eligible_samples:
if not all_samples_contain_system_prompt(eligible_samples, SYSTEM_PROMPT):
return False
else:
print(
"MMLU was run in zero-shot mode, cannot confirm that system prompt was included, skipping check..."
)

except Exception as exc:
print(f"'test_minimal_mmlu' failed: {exc}")
return False
86 changes: 63 additions & 23 deletions src/instructlab/eval/mmlu.py
@@ -7,12 +7,12 @@
"""

# Standard
from typing import Optional, Union
from typing import Any, Dict, Optional, Union
import os

# Third Party
from lm_eval.evaluator import simple_evaluate # type: ignore
from lm_eval.tasks import TaskManager # type: ignore
from lm_eval.evaluator import simple_evaluate
from lm_eval.tasks import TaskManager
import torch

# First Party
@@ -102,6 +102,8 @@ class AbstractMMLUEvaluator(Evaluator):
few_shots number of examples
batch_size batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times'.
device PyTorch device (e.g. "cpu" or "cuda:0") for running models
system_prompt system prompt to be used when applying the chat template
results full output from the `lm_eval.evaluator.simple_evaluate` function after MMLU has run.
"""

def __init__(
@@ -113,26 +115,43 @@ def __init__(
few_shots: int = 5,
batch_size: Optional[Union[int, str]] = "auto",
device: str = ("cuda" if torch.cuda.is_available() else "cpu"),
system_prompt: Optional[str] = None,
) -> None:
self.model_path = model_path
self.system_prompt = system_prompt
self.tasks_dir = tasks_dir
self.tasks = tasks
self.model_dtype = model_dtype
self.few_shots = few_shots
self.batch_size = batch_size
self.device = device
self._results = None

def run(self, server_url: str | None = None) -> tuple:
@property
def results(self) -> Dict[str, Any] | None:
"""
Returns the results of the last MMLU evaluation, if one has taken place.
Returns:
Dict[str, Any] | None: The output from `lm_eval.evaluator.simple_evaluate`
"""
return self._results

def run(
self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None
) -> tuple:
"""
Runs evaluation
Attributes
server_url Model server endpoint (Ex: http://localhost:8000/v1) for the model being evaluated
extra_args Dictionary of any extra arguments to be passed through to the `lm_eval.evaluator.simple_evaluate` function.
Returns:
overall_score Average score for the task group
individual_scores Individual scores for each task in the task group
"""
extra_args = {} if not extra_args else extra_args
logger.debug(locals())

# TODO: make this a parameter for class?
@@ -153,7 +172,10 @@ def run(self, server_url: str | None = None) -> tuple:

return overall_score, individual_scores

def _run_mmlu(self, server_url: str | None = None) -> dict:
def _run_mmlu(
self, server_url: str | None = None, extra_args: Dict[str, Any] | None = None
) -> dict:
extra_args = {} if not extra_args else extra_args
if server_url is not None:
# Requires lm_eval >= 0.4.4
model_args = f"base_url={server_url}/completions,model={self.model_path},tokenizer_backend=huggingface"
@@ -168,17 +190,25 @@ def _run_mmlu(self, server_url: str | None = None) -> dict:
if not os.access(self.tasks_dir, os.R_OK):
raise InvalidTasksDirError(self.tasks_dir)
tm = TaskManager(verbosity="DEBUG", include_path=self.tasks_dir)
mmlu_output = self._simple_evaluate_with_error_handling(
model=model,
model_args=model_args,
tasks=self.tasks,
num_fewshot=self.few_shots,
batch_size=self.batch_size,
device=self.device,
task_manager=tm,
)
results = mmlu_output["results"]
return results
should_apply_chat_template = self.system_prompt is not None

# configure the args here so users can override them as necessary
simple_evaluate_kwargs = {
"model": model,
"model_args": model_args,
"tasks": self.tasks,
"num_fewshot": self.few_shots,
"batch_size": self.batch_size,
"device": self.device,
"task_manager": tm,
"system_instruction": self.system_prompt,
"apply_chat_template": should_apply_chat_template,
}
simple_evaluate_kwargs.update(extra_args)

results = self._simple_evaluate_with_error_handling(**simple_evaluate_kwargs)
self._results = results
return results["results"]

# This method converts general errors from simple_evaluate
# into a more user-understandable error
@@ -213,12 +243,13 @@ class MMLUEvaluator(AbstractMMLUEvaluator):
Evaluator for Massive Multitask Language Understanding (MMLU)
Attributes:
model_path absolute path to or name of a huggingface model
tasks list of tasks for MMLU to test the model with
model_dtype dtype of model when served
few_shots number of examples
batch_size batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times'.
device PyTorch device (e.g. "cpu" or "cuda:0") for running models
model_path absolute path to or name of a huggingface model
tasks list of tasks for MMLU to test the model with
model_dtype dtype of model when served
few_shots number of examples
batch_size batch size for evaluation. Valid values are a positive integer or 'auto' to select the largest batch size that will fit in memory, or 'auto:N' to reselect the largest batch size N times'.
device PyTorch device (e.g. "cpu" or "cuda:0") for running models
system_prompt system prompt to be used when applying the chat template
"""

name = "mmlu"
@@ -231,9 +262,17 @@ def __init__(
few_shots: int = 5,
batch_size: Optional[Union[int, str]] = "auto",
device: str = ("cuda" if torch.cuda.is_available() else "cpu"),
system_prompt: Optional[str] = None,
) -> None:
super().__init__(
model_path, None, tasks, model_dtype, few_shots, batch_size, device
model_path,
None,
tasks,
model_dtype,
few_shots,
batch_size,
device,
system_prompt=system_prompt,
)


@@ -243,6 +282,7 @@ class MMLUBranchEvaluator(AbstractMMLUEvaluator):
Attributes:
model_path absolute path to or name of a huggingface model
system_prompt system prompt to be used when applying the chat template
tasks_dir path where the <TASK_NAME>.jsonl and <TASK_NAME>_task.yaml files for the branches being evaluated are stored
tasks group name that is shared by all the MMLUBranch tasks
model_dtype dtype of model when served
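
Since `_run_mmlu` merges the caller's `extra_args` into its defaults with a plain `dict.update`, any key the caller supplies overrides the evaluator's own value. A small sketch of that merge behaviour, using key names from the diff above (values are illustrative):

```python
# Sketch of the override semantics implied by simple_evaluate_kwargs.update(extra_args).
defaults = {
    "num_fewshot": 5,             # from few_shots
    "batch_size": "auto",
    "apply_chat_template": True,  # True whenever a system_prompt was set
}
extra_args = {"num_fewshot": 2, "log_samples": True}

merged = dict(defaults)
merged.update(extra_args)  # caller-supplied keys win over the defaults
assert merged["num_fewshot"] == 2
assert merged["log_samples"] is True
```
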
11 changes: 9 additions & 2 deletions tests/test_mmlu.py
@@ -48,7 +48,10 @@ def test_mmlu_branch(eval_mock):
tasks_dir = f"{os.path.dirname(os.path.realpath(__file__))}/testdata/sdg"
tasks = ["mmlu_pr"]
mmlu = MMLUBranchEvaluator(
model_path=MODEL_EXAMPLE, tasks_dir=tasks_dir, tasks=tasks
model_path=MODEL_EXAMPLE,
tasks_dir=tasks_dir,
tasks=tasks,
system_prompt="You are an intelligent AI language model.",
)
overall_score, individual_scores = mmlu.run()

@@ -62,7 +65,11 @@
)
def test_mmlu(eval_mock):
tasks = ["mmlu_anatomy", "mmlu_astronomy", "mmlu_algebra"]
mmlu = MMLUEvaluator(model_path=MODEL_EXAMPLE, tasks=tasks)
mmlu = MMLUEvaluator(
model_path=MODEL_EXAMPLE,
tasks=tasks,
system_prompt="You are an intelligent AI language model.",
)
overall_score, individual_scores = mmlu.run()

eval_mock.assert_called()
