feat: add support for o1-mini and o1-preview, which do not support JSON mode, structured outputs, or system prompts
provos committed Oct 30, 2024
1 parent 1ad3462 commit c242067
Showing 3 changed files with 86 additions and 4 deletions.
4 changes: 4 additions & 0 deletions src/planai/llm_config.py
@@ -30,11 +30,15 @@ def llm_from_config(
"gpt-4o-2024-08-06",
"gpt-4o",
]
support_json_mode = model_name not in ["o1-mini", "o1-preview"]
support_system_prompt = model_name not in ["o1-mini", "o1-preview"]
return LLMInterface(
model_name=model_name,
log_dir=log_dir,
client=wrapper,
support_json_mode=support_json_mode,
support_structured_outputs=support_structured_outputs,
support_system_prompt=support_system_prompt,
)
case "anthropic":
api_key = os.getenv("ANTHROPIC_API_KEY")
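For orientation, here is a minimal usage sketch of how the new flags surface to a caller of llm_from_config. The provider keyword name is an assumption (the full signature is truncated in the hunk above), and the call presumably requires OpenAI credentials to be configured in the environment:

from planai.llm_config import llm_from_config

# Hypothetical call; the argument names are assumed, not taken from the hunk.
llm = llm_from_config(provider="openai", model_name="o1-mini")

# For o1-mini and o1-preview the returned LLMInterface is created with JSON
# mode and system-prompt support disabled; structured outputs stay limited to
# the gpt-4o family listed above, so that flag is disabled as well.
assert llm.support_json_mode is False
assert llm.support_system_prompt is False
assert llm.support_structured_outputs is False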
14 changes: 10 additions & 4 deletions src/planai/llm_interface.py
@@ -38,11 +38,13 @@ def __init__(
host: Optional[str] = None,
support_json_mode: bool = True,
support_structured_outputs: bool = False,
support_system_prompt: bool = True,
):
self.model_name = model_name
self.client = client if client else Client(host=host)
self.support_json_mode = support_json_mode
self.support_structured_outputs = support_structured_outputs
self.support_system_prompt = support_system_prompt

self.logger = setup_logging(
logs_dir=log_dir, logs_prefix="llm_interface", logger_name=__name__
@@ -211,10 +213,10 @@ def generate_pydantic(
if logger:
logger.info("Generated prompt: %s", formatted_prompt)

messages = [
{"role": "system", "content": system},
{"role": "user", "content": formatted_prompt},
]
messages = []
if self.support_system_prompt:
messages.append({"role": "system", "content": system})
messages.append({"role": "user", "content": formatted_prompt})

iteration = 0
while iteration < 3:
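As a standalone illustration of the new message construction (a minimal sketch of the logic above, not the library code itself):

def build_messages(system: str, formatted_prompt: str, support_system_prompt: bool):
    # o1-mini and o1-preview reject the "system" role, so the system text is
    # simply omitted for them and only the user turn is sent.
    messages = []
    if support_system_prompt:
        messages.append({"role": "system", "content": system})
    messages.append({"role": "user", "content": formatted_prompt})
    return messages

assert build_messages("Be terse.", "Hello", support_system_prompt=False) == [
    {"role": "user", "content": "Hello"}
]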
@@ -255,6 +257,10 @@ def generate_pydantic(
if self.support_structured_outputs:
# the raw response was a pydantic object, so we need to dump it to a string
raw_response = raw_response.model_dump_json()
elif not isinstance(raw_response, str):
raise ValueError(
"The response should be a string if the model does not support structured outputs."
)
messages.extend(
[
{"role": "assistant", "content": raw_response},
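The same response handling, pulled out of the surrounding diff as a self-contained sketch (the helper name is invented for this illustration):

def normalize_raw_response(raw_response, support_structured_outputs: bool) -> str:
    if support_structured_outputs:
        # The raw response is a pydantic object, so dump it to a JSON string
        # before appending it to the conversation as an assistant message.
        return raw_response.model_dump_json()
    if not isinstance(raw_response, str):
        raise ValueError(
            "The response should be a string if the model does not support structured outputs."
        )
    return raw_response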
72 changes: 72 additions & 0 deletions tests/planai/test_llm_interface.py
@@ -263,6 +263,78 @@ class StructuredOutputModel(BaseModel):
self.assertEqual(call_args["messages"], expected_messages)
self.assertEqual(call_args["response_schema"], StructuredOutputModel)

def test_chat_without_system_prompt(self):
# Create a dummy Pydantic model as output schema
class StructuredOutputModel(BaseModel):
field1: str
field2: int

# Mock the structured response from the chat method
structured_response = StructuredOutputModel(field1="direct", field2=123)
self.llm_interface.support_structured_outputs = True
self.llm_interface.support_system_prompt = (
False # Simulate model without system prompt support
)

self.mock_client.chat.return_value = {
"message": {"content": structured_response}
}

# Performing the test
response = self.llm_interface.generate_pydantic(
prompt_template="Dummy prompt",
output_schema=StructuredOutputModel,
system=self.system,
)

# Assertions to ensure the response is directly the structured output
self.assertEqual(response, structured_response)

# Ensure chat was called once with expected messages
self.mock_client.chat.assert_called_once()

# Check the message format
expected_messages = [
{"role": "user", "content": "Dummy prompt"},
]
call_args = self.mock_client.chat.call_args[1]
self.assertEqual(call_args["messages"], expected_messages)
self.assertEqual(call_args["response_schema"], StructuredOutputModel)

def test_generate_pydantic_without_json_mode(self):
# Create a dummy Pydantic model as output schema
class StructuredOutputModel(BaseModel):
field1: str
field2: int

# Mock the response from the chat method
raw_response = '{"field1": "direct", "field2": 123}'
stripped_response = '{"field1": "direct", "field2": 123}' # Assuming the stripped response is the same for simplicity
self.llm_interface.support_json_mode = (
False # Simulate model without JSON mode support
)

self.mock_client.chat.return_value = {"message": {"content": raw_response}}

with patch.object(
self.llm_interface,
"_strip_text_from_json_response",
return_value=stripped_response,
) as mock_strip:
# Performing the test
response = self.llm_interface.generate_pydantic(
prompt_template="Dummy prompt",
output_schema=StructuredOutputModel,
system=self.system,
)

# Assertions to ensure the response is correctly parsed
expected_response = StructuredOutputModel(field1="direct", field2=123)
self.assertEqual(response, expected_response)

# Ensure _strip_text_from_json_response was called
mock_strip.assert_called_once_with(raw_response)


if __name__ == "__main__":
unittest.main()
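The JSON-mode fallback exercised by test_generate_pydantic_without_json_mode relies on the existing _strip_text_from_json_response helper, whose implementation is not part of this diff. Purely as a hypothetical sketch of what such a stripping step can look like, assuming the goal is to keep only the outermost JSON object:

import json

def strip_text_from_json_response(raw: str) -> str:
    # Hypothetical: models without JSON mode often wrap the JSON payload in
    # prose or code fences, so keep only the outermost object.
    start, end = raw.find("{"), raw.rfind("}")
    if start == -1 or end == -1:
        raise ValueError("No JSON object found in the response.")
    return raw[start : end + 1]

assert json.loads(
    strip_text_from_json_response('Sure! {"field1": "direct", "field2": 123}')
) == {"field1": "direct", "field2": 123}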
