-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add MockLLM and MockLLMResponse classes to allow users easier t…
…esting for their own workflows
- Loading branch information
Showing
3 changed files
with
229 additions
and
0 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import re | ||
from typing import Dict, List, Optional, Pattern, Type, Union | ||
|
||
from pydantic import BaseModel | ||
|
||
from ..llm_interface import LLMInterface, NoCache | ||
|
||
|
||
class MockLLMResponse: | ||
"""Configuration for a mock response.""" | ||
|
||
def __init__( | ||
self, | ||
pattern: Union[str, Pattern], | ||
response: Optional[BaseModel] = None, | ||
raise_exception: bool = False, | ||
exception: Optional[Exception] = None, | ||
): | ||
""" | ||
Args: | ||
pattern: Regex pattern to match against the full prompt | ||
response: Pydantic object to return when pattern matches | ||
raise_exception: If True, raise an exception instead of returning response | ||
exception: Specific exception to raise (defaults to ValueError if not specified) | ||
""" | ||
self.pattern = re.compile(pattern) if isinstance(pattern, str) else pattern | ||
self.response = response | ||
self.raise_exception = raise_exception | ||
self.exception = exception or ValueError("Mock LLM error") | ||
|
||
|
||
class MockLLM(LLMInterface): | ||
"""A mock LLM implementation for testing purposes.""" | ||
|
||
def __init__( | ||
self, | ||
responses: List[MockLLMResponse], | ||
support_structured_outputs: bool = True, | ||
): | ||
""" | ||
Args: | ||
responses: List of MockLLMResponse objects defining pattern-response pairs | ||
support_structured_outputs: Whether to return structured outputs directly | ||
""" | ||
super().__init__( | ||
model_name="mock", | ||
client=None, | ||
support_structured_outputs=support_structured_outputs, | ||
use_cache=False, | ||
) | ||
self.responses = responses | ||
self.disk_cache = NoCache() | ||
|
||
def _cached_chat( | ||
self, | ||
messages: List[Dict[str, str]], | ||
tools: Optional[List] = None, | ||
temperature: Optional[float] = None, | ||
response_schema: Optional[Type[BaseModel]] = None, | ||
) -> str: | ||
# Reconstruct the full prompt from messages | ||
full_prompt = "" | ||
for msg in messages: | ||
if msg["role"] == "system": | ||
full_prompt += f"System: {msg['content']}\n" | ||
elif msg["role"] == "user": | ||
full_prompt += f"User: {msg['content']}\n" | ||
|
||
# Try to match the prompt against our patterns | ||
for mock_response in self.responses: | ||
if mock_response.pattern.search(full_prompt): | ||
if mock_response.raise_exception: | ||
raise mock_response.exception | ||
|
||
if mock_response.response is None: | ||
return None | ||
|
||
if self.support_structured_outputs: | ||
return mock_response.response | ||
else: | ||
# Convert Pydantic object to JSON string | ||
return mock_response.response.model_dump_json() | ||
|
||
raise ValueError(f"No matching mock response found for prompt: {full_prompt}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,145 @@ | ||
import unittest | ||
from typing import List, Optional | ||
from pydantic import BaseModel | ||
|
||
from planai.testing.mock_llm import MockLLM, MockLLMResponse | ||
|
||
|
||
# Test Pydantic models | ||
class Person(BaseModel): | ||
name: str | ||
age: int | ||
hobbies: Optional[List[str]] = None | ||
|
||
|
||
class Address(BaseModel): | ||
street: str | ||
city: str | ||
country: str | ||
|
||
|
||
class TestMockLLM(unittest.TestCase): | ||
def setUp(self): | ||
# Create some sample mock responses | ||
self.responses = [ | ||
MockLLMResponse( | ||
pattern=r"Get person named Alice", | ||
response=Person(name="Alice", age=30, hobbies=["reading", "hiking"]), | ||
), | ||
MockLLMResponse( | ||
pattern=r"Get person named Bob", | ||
response=Person(name="Bob", age=25, hobbies=["gaming"]), | ||
), | ||
MockLLMResponse( | ||
pattern=r"Get invalid person", | ||
raise_exception=True, | ||
exception=ValueError("Invalid person request"), | ||
), | ||
MockLLMResponse(pattern=r"Get none result", response=None), | ||
MockLLMResponse( | ||
pattern=r"Get address.*", | ||
response=Address( | ||
street="123 Main St", city="Springfield", country="USA" | ||
), | ||
), | ||
] | ||
self.mock_llm = MockLLM(responses=self.responses) | ||
|
||
def test_successful_response(self): | ||
"""Test successful response matching.""" | ||
result = self.mock_llm.generate_pydantic( | ||
prompt_template="Get person named Alice", | ||
output_schema=Person, | ||
system="Test system prompt", | ||
) | ||
|
||
self.assertIsInstance(result, Person) | ||
self.assertEqual(result.name, "Alice") | ||
self.assertEqual(result.age, 30) | ||
self.assertEqual(result.hobbies, ["reading", "hiking"]) | ||
|
||
def test_multiple_patterns(self): | ||
"""Test multiple different patterns work correctly.""" | ||
result1 = self.mock_llm.generate_pydantic( | ||
prompt_template="Get person named Alice", output_schema=Person | ||
) | ||
result2 = self.mock_llm.generate_pydantic( | ||
prompt_template="Get person named Bob", output_schema=Person | ||
) | ||
|
||
self.assertEqual(result1.name, "Alice") | ||
self.assertEqual(result2.name, "Bob") | ||
|
||
def test_exception_raising(self): | ||
"""Test that exceptions are raised correctly.""" | ||
with self.assertRaises(ValueError) as context: | ||
self.mock_llm.generate_pydantic( | ||
prompt_template="Get invalid person", output_schema=Person | ||
) | ||
|
||
self.assertEqual(str(context.exception), "Invalid person request") | ||
|
||
def test_none_response(self): | ||
"""Test that None responses are handled correctly.""" | ||
result = self.mock_llm.generate_pydantic( | ||
prompt_template="Get none result", output_schema=Person | ||
) | ||
|
||
self.assertIsNone(result) | ||
|
||
def test_no_matching_pattern(self): | ||
"""Test behavior when no pattern matches.""" | ||
with self.assertRaises(ValueError) as context: | ||
self.mock_llm.generate_pydantic( | ||
prompt_template="This won't match any pattern", output_schema=Person | ||
) | ||
|
||
self.assertIn("No matching mock response found", str(context.exception)) | ||
|
||
def test_regex_pattern_matching(self): | ||
"""Test that regex patterns work correctly.""" | ||
result = self.mock_llm.generate_pydantic( | ||
prompt_template="Get address for John Doe", output_schema=Address | ||
) | ||
|
||
self.assertIsInstance(result, Address) | ||
self.assertEqual(result.street, "123 Main St") | ||
self.assertEqual(result.city, "Springfield") | ||
|
||
def test_different_output_schemas(self): | ||
"""Test that different output schemas work correctly.""" | ||
person_result = self.mock_llm.generate_pydantic( | ||
prompt_template="Get person named Alice", output_schema=Person | ||
) | ||
address_result = self.mock_llm.generate_pydantic( | ||
prompt_template="Get address for someone", output_schema=Address | ||
) | ||
|
||
self.assertIsInstance(person_result, Person) | ||
self.assertIsInstance(address_result, Address) | ||
|
||
def test_with_system_prompt(self): | ||
"""Test that system prompts are included in pattern matching.""" | ||
result = self.mock_llm.generate_pydantic( | ||
prompt_template="Get person named Alice", | ||
system="This is a test system prompt", | ||
output_schema=Person, | ||
) | ||
|
||
self.assertIsInstance(result, Person) | ||
self.assertEqual(result.name, "Alice") | ||
|
||
def test_json_mode(self): | ||
"""Test MockLLM with JSON mode (non-structured outputs).""" | ||
mock_llm = MockLLM(responses=self.responses, support_structured_outputs=False) | ||
|
||
result = mock_llm.generate_pydantic( | ||
prompt_template="Get person named Alice", output_schema=Person | ||
) | ||
|
||
self.assertIsInstance(result, Person) | ||
self.assertEqual(result.name, "Alice") | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |