Allowing open answer prompting with MultipleChoiceQuestion (#175)
jamesbraza authored Jan 14, 2025
1 parent 22386b8 commit d46a953
Showing 3 changed files with 65 additions and 12 deletions.
4 changes: 4 additions & 0 deletions src/aviary/core.py
@@ -38,6 +38,8 @@
 )
 from aviary.utils import (
     EvalAnswerMode,
+    MultipleChoiceEvaluation,
+    MultipleChoiceQuestion,
     encode_image_to_base64,
     eval_answer,
     extract_answer,
@@ -62,6 +64,8 @@
     "Message",
     "Messages",
     "MessagesAdapter",
+    "MultipleChoiceEvaluation",
+    "MultipleChoiceQuestion",
     "Parameters",
     "Renderer",
     "TaskConfig",
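With these re-exports in place, both classes should be importable directly from the package root (as the updated tests below also do); a minimal illustration:

from aviary.core import MultipleChoiceEvaluation, MultipleChoiceQuestion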
17 changes: 15 additions & 2 deletions src/aviary/utils.py
@@ -217,7 +217,11 @@ async def extract_answer(
 class MultipleChoiceQuestion(BaseModel):
     model_config = ConfigDict(extra="forbid")

-    QUESTION_PROMPT_TEMPLATE: ClassVar[str] = "Q: {question}\n\nOptions:\n{options}"
+    OPEN_ANSWER_PROMPT_TEMPLATE: ClassVar[str] = "Q: {question}"
+    MC_QUESTION_PROMPT_TEMPLATE: ClassVar[str] = "\n\n".join((
+        OPEN_ANSWER_PROMPT_TEMPLATE,
+        "Options:\n{options}",
+    ))
     DEFAULT_UNSURE_OPTION: ClassVar[str] = (
         "Insufficient information to answer this question"
     )
@@ -228,6 +232,13 @@ class MultipleChoiceQuestion(BaseModel):
     question: str = Field(
         description="Question to answer (without multiple choice options)."
     )
+    prompt_without_options: bool = Field(
+        default=False,
+        description=(
+            "Opt-in flag to exclude options from the question_prompt, effectively"
+            " making the prompt be open answer."
+        ),
+    )
     options: Sequence[str] = Field(description="All multiple choice options.")
     ideal_answer: str = Field(
         description=(
@@ -284,7 +295,9 @@ def unsure_answer_index(self) -> int | None:

     @property
     def question_prompt(self) -> str:
-        return self.QUESTION_PROMPT_TEMPLATE.format(
+        if self.prompt_without_options:
+            return self.OPEN_ANSWER_PROMPT_TEMPLATE.format(question=self.question)
+        return self.MC_QUESTION_PROMPT_TEMPLATE.format(
             question=self.question,
             options="\n".join([
                 f"{_CAPITAL_A_INDEX + i:c}) {o}" for i, o in enumerate(self.options)
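Taken together, the utils.py changes keep the default rendering (question plus lettered options) and add an opt-in open-answer form. A minimal usage sketch of the new flag, assuming illustrative question/answer values that are not part of this commit:

from aviary.core import MultipleChoiceQuestion

open_q = MultipleChoiceQuestion(
    question="What is the meaning of life?",  # illustrative value, not from this commit
    ideal_answer="42",
    options=["-84", "11", "cheesecake"],
    shuffle_seed=0,
    prompt_without_options=True,  # the new opt-in flag
)
# With the flag set, OPEN_ANSWER_PROMPT_TEMPLATE is used, so no options are listed
print(open_q.question_prompt)  # prints: Q: What is the meaning of life?

Without prompt_without_options=True, the same instance would render MC_QUESTION_PROMPT_TEMPLATE instead: the question followed by an "Options:" block of lettered choices.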
56 changes: 46 additions & 10 deletions tests/test_utils.py
@@ -2,8 +2,12 @@

 import pytest

-from aviary.core import eval_answer, extract_answer
-from aviary.utils import MultipleChoiceEvaluation, MultipleChoiceQuestion
+from aviary.core import (
+    MultipleChoiceEvaluation,
+    MultipleChoiceQuestion,
+    eval_answer,
+    extract_answer,
+)
 from tests.conftest import VCR_DEFAULT_MATCH_ON

@@ -82,15 +86,12 @@ def _assert_prompt_is_valid(
         question: str,
         ideal_answer: str,
         distractors: Iterable[str],
+        is_open_answer: bool = False,
     ) -> None:
         question_prompt = mc_question.question_prompt
-        for substr in (
-            question,
-            "Insufficient information",
-            ideal_answer,
-            *distractors,
-        ):
-            assert question_prompt.count(substr) == 1
+        assert question_prompt.count(question) == 1
+        for substr in ("Insufficient information", ideal_answer, *distractors):
+            assert question_prompt.count(substr) == (1 if not is_open_answer else 0)

     # Use for general purpose testing
     ZIP_CODE_QUESTION_IDEAL_DISTRACTORS = (
@@ -232,7 +233,7 @@ def test_consistent_mc_options(self) -> None:

         mc_question_1a_copy = MultipleChoiceQuestion(**mc_question_1a.model_dump())
         self._assert_prompt_is_valid(mc_question_1a_copy, question, ideal, distractors)
-        assert mc_question_1a == mc_question_1b, (
+        assert mc_question_1a == mc_question_1a_copy == mc_question_1b, (
             "Serialization then deserialization should lead to same prompts"
         )

@@ -258,6 +259,41 @@ def test_consistent_mc_options(self) -> None:
             "Different seeding strategies should lead to different prompts"
         )

+    def test_consistent_open_answer(self) -> None:
+        question, ideal, distractors = self.MEANING_OF_LIFE_QUESTION_IDEAL_DISTRACTORS
+        mc_question_1a = MultipleChoiceQuestion(
+            question=question,
+            ideal_answer=ideal,
+            options=distractors,
+            shuffle_seed=0,
+            prompt_without_options=True,
+        )
+        self._assert_prompt_is_valid(
+            mc_question_1a, question, ideal, distractors, is_open_answer=True
+        )
+
+        mc_question_1b = MultipleChoiceQuestion(
+            question=question,
+            ideal_answer=ideal,
+            options=distractors,
+            shuffle_seed=0,
+            prompt_without_options=True,
+        )
+        self._assert_prompt_is_valid(
+            mc_question_1b, question, ideal, distractors, is_open_answer=True
+        )
+        assert mc_question_1a == mc_question_1b, (
+            "Same seeding should lead to same prompts"
+        )
+
+        mc_question_1a_copy = MultipleChoiceQuestion(**mc_question_1a.model_dump())
+        self._assert_prompt_is_valid(
+            mc_question_1a_copy, question, ideal, distractors, is_open_answer=True
+        )
+        assert mc_question_1a == mc_question_1a_copy == mc_question_1b, (
+            "Serialization then deserialization should lead to same prompts"
+        )
+

 class TestMultipleChoiceEvaluation:
     @pytest.mark.parametrize(
