diff --git a/README.md b/README.md
index 7c8d667..51023ba 100644
--- a/README.md
+++ b/README.md
@@ -21,7 +21,7 @@ scripts/infra/cloud-instance.sh ec2 ssh
 # Regardless of how you setup your instance
 git clone https://github.com/instructlab/taxonomy.git && pushd taxonomy && git branch rc && popd
 git clone --bare https://github.com/instructlab/eval.git && git clone eval.git/ && cd eval && git remote add syncrepo ../eval.git
-python -m venv venv
+python3 -m venv venv
 source venv/bin/activate
 pip install -r requirements.txt
 pip install -r requirements-dev.txt
diff --git a/pyproject.toml b/pyproject.toml
index b11c7bd..03faef9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -96,3 +96,6 @@ from-first = true
 # import-heading-firstparty=First Party
 # import-heading-localfolder=Local
 known-local-folder = ["tuning"]
+
+[tool.mypy]
+ignore_missing_imports = true
diff --git a/src/instructlab/eval/exceptions.py b/src/instructlab/eval/exceptions.py
index 1037d90..b3c7ca5 100644
--- a/src/instructlab/eval/exceptions.py
+++ b/src/instructlab/eval/exceptions.py
@@ -107,3 +107,15 @@ def __init__(self, tasks_dir) -> None:
         super().__init__()
         self.tasks_dir = tasks_dir
         self.message = f"Invalid Tasks Dir: {tasks_dir}"
+
+
+class ModelServingAPIError(EvalError):
+    """
+    Error raised when reply retrieval from model serving fails.
+    Attributes
+        message         error message to be printed on raise
+    """
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.message = "Failed to receive a reply from model serving API."
diff --git a/src/instructlab/eval/mt_bench_common.py b/src/instructlab/eval/mt_bench_common.py
index 86f9920..ee23546 100644
--- a/src/instructlab/eval/mt_bench_common.py
+++ b/src/instructlab/eval/mt_bench_common.py
@@ -4,7 +4,7 @@
 """
 
 # Standard
-from typing import Optional
+from typing import Optional, TypedDict
 import ast
 import dataclasses
 import glob
@@ -14,9 +14,13 @@
 import time
 
 # Third Party
+from fastchat import conversation
 from fastchat.model.model_adapter import get_conversation_template  # type: ignore
 import openai
 
+# First Party
+from instructlab.eval import exceptions
+
 # Local
 from .logger_config import setup_logger
 
@@ -247,21 +251,56 @@ def play_a_match_single(
     return result
 
 
+def _is_fatal_openai_error(e: openai.OpenAIError) -> bool:
+    return isinstance(
+        e,
+        (
+            openai.APIConnectionError,
+            openai.AuthenticationError,
+            openai.PermissionDeniedError,
+            openai.NotFoundError,
+        ),
+    )
+
+
+# TODO: copied from instructlab (cli) utils module; consolidate somewhere?
+class Message(TypedDict):
+    """
+    Represents a message within an AI conversation.
+    """
+
+    content: str
+    # one of: "user", "assistant", or "system"
+    role: str
+
+
+def _get_messages(
+    conv: conversation.Conversation, merge_system_user_message: bool
+) -> list[Message]:
+    messages = conv.to_openai_api_messages()
+    if (
+        (merge_system_user_message or conv.name == "mistral")
+        and messages[0]["role"] == "system"
+        and messages[1]["role"] == "user"
+    ):
+        messages[1]["content"] = messages[0]["content"] + "\n" + messages[1]["content"]
+        return messages[1:]
+    return messages
+
+
 def chat_completion_openai(
-    openai_client, model, conv, temperature, max_tokens, merge_system_user_message=False
+    openai_client,
+    model,
+    conv: conversation.Conversation,
+    temperature,
+    max_tokens,
+    merge_system_user_message: bool = False,
 ) -> str:
+    output = None
+    messages = _get_messages(conv, merge_system_user_message)
+
     for i in range(API_MAX_RETRY):
         try:
-            messages = conv.to_openai_api_messages()
-            if (
-                (merge_system_user_message or conv.name == "mistral")
-                and messages[0]["role"] == "system"
-                and messages[1]["role"] == "user"
-            ):
-                messages[1]["content"] = (
-                    messages[0]["content"] + "\n" + messages[1]["content"]
-                )
-                messages = messages[1:]
             response = openai_client.chat.completions.create(
                 model=model,
                 messages=messages,
@@ -269,15 +308,42 @@ def chat_completion_openai(
                 temperature=temperature,
                 max_tokens=max_tokens,
             )
-            return response.choices[0].message.content
-        except openai.OpenAIError as e:
+            output = response.choices[0].message.content
+            break
+        except (
+            # retry won't fix these errors
+            openai.BadRequestError,  # 400
+            openai.UnprocessableEntityError,  # 422
+        ) as e:
+            logger.debug(e)
+            return API_ERROR_OUTPUT  # immediately soft fail
+        except (
+            # retry may help with these errors
+            openai.APIConnectionError,
+            openai.RateLimitError,  # 429
+            openai.InternalServerError,  # >=500
+            # NOTE: Errors listed below may need a revisit: we are not sure if
+            # it's ever helpful to retry them. Leaving them intact for now.
+            openai.AuthenticationError,  # 401
+            openai.PermissionDeniedError,  # 403
+            openai.NotFoundError,  # 404
+            # General catch-all
+            openai.OpenAIError,
+        ) as e:
+            if not _is_fatal_openai_error(e):
+                output = API_ERROR_OUTPUT  # disable hard fail (never raise!)
+            # still, retry in the hope we'll get a successful reply
             if i == API_MAX_RETRY - 1:
                 logger.error(e)
                 break
             logger.debug(e)
             time.sleep(API_RETRY_SLEEP)
 
-    return API_ERROR_OUTPUT
+    if output is None:
+        # not a single attempt was non-fatal; this is indicative of
+        # basic connectivity or server issue -> hard fail
+        raise exceptions.ModelServingAPIError
+    return output
 
 
 def check_data(questions, model_answers, ref_answers, models, judges):