feat: add MockLLM and MockLLMResponse classes to allow users easier testing for their own workflows
provos committed Dec 10, 2024
1 parent ef1d793 commit d4fb18d
Showing 3 changed files with 229 additions and 0 deletions.
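Together, the new classes let a test author script an LLM's behavior by pairing regex patterns with canned Pydantic responses. A minimal sketch of the intended use (the Greeting model and prompt text here are hypothetical; generate_pydantic is the same entry point exercised by the unit tests below):

from pydantic import BaseModel

from planai.testing.mock_llm import MockLLM, MockLLMResponse


class Greeting(BaseModel):
    # Hypothetical output model, for illustration only
    text: str


# Any prompt matching the regex returns the canned Greeting instance
llm = MockLLM(
    responses=[
        MockLLMResponse(pattern=r"Say hello", response=Greeting(text="hello")),
    ]
)

result = llm.generate_pydantic(prompt_template="Say hello", output_schema=Greeting)
assert result.text == "hello"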
Empty file added src/planai/testing/__init__.py
84 changes: 84 additions & 0 deletions src/planai/testing/mock_llm.py
@@ -0,0 +1,84 @@
import re
from typing import Dict, List, Optional, Pattern, Type, Union

from pydantic import BaseModel

from ..llm_interface import LLMInterface, NoCache


class MockLLMResponse:
    """Configuration for a mock response."""

    def __init__(
        self,
        pattern: Union[str, Pattern],
        response: Optional[BaseModel] = None,
        raise_exception: bool = False,
        exception: Optional[Exception] = None,
    ):
        """
        Args:
            pattern: Regex pattern to match against the full prompt
            response: Pydantic object to return when the pattern matches
            raise_exception: If True, raise an exception instead of returning a response
            exception: Specific exception to raise (defaults to ValueError if not specified)
        """
        self.pattern = re.compile(pattern) if isinstance(pattern, str) else pattern
        self.response = response
        self.raise_exception = raise_exception
        self.exception = exception or ValueError("Mock LLM error")


class MockLLM(LLMInterface):
    """A mock LLM implementation for testing purposes."""

    def __init__(
        self,
        responses: List[MockLLMResponse],
        support_structured_outputs: bool = True,
    ):
        """
        Args:
            responses: List of MockLLMResponse objects defining pattern-response pairs
            support_structured_outputs: Whether to return structured outputs directly
        """
        super().__init__(
            model_name="mock",
            client=None,
            support_structured_outputs=support_structured_outputs,
            use_cache=False,
        )
        self.responses = responses
        self.disk_cache = NoCache()

    def _cached_chat(
        self,
        messages: List[Dict[str, str]],
        tools: Optional[List] = None,
        temperature: Optional[float] = None,
        response_schema: Optional[Type[BaseModel]] = None,
    ) -> Optional[Union[str, BaseModel]]:
        # Reconstruct the full prompt from the messages
        full_prompt = ""
        for msg in messages:
            if msg["role"] == "system":
                full_prompt += f"System: {msg['content']}\n"
            elif msg["role"] == "user":
                full_prompt += f"User: {msg['content']}\n"

        # Try to match the prompt against our patterns
        for mock_response in self.responses:
            if mock_response.pattern.search(full_prompt):
                if mock_response.raise_exception:
                    raise mock_response.exception

                if mock_response.response is None:
                    return None

                if self.support_structured_outputs:
                    # Return the Pydantic object directly
                    return mock_response.response
                else:
                    # Convert the Pydantic object to a JSON string
                    return mock_response.response.model_dump_json()

        raise ValueError(f"No matching mock response found for prompt: {full_prompt}")
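Note that _cached_chat rebuilds the prompt with "System: " and "User: " prefixes before matching, so a mock response can also be keyed to the system prompt rather than the user prompt. A brief sketch, reusing the hypothetical Greeting model from the sketch above:

# The pattern targets the reconstructed "System: ..." line
review_llm = MockLLM(
    responses=[
        MockLLMResponse(
            pattern=r"System: You are a reviewer",
            response=Greeting(text="looks good"),
        )
    ]
)
result = review_llm.generate_pydantic(
    prompt_template="Review this change",
    system="You are a reviewer",
    output_schema=Greeting,
)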
145 changes: 145 additions & 0 deletions tests/planai/testing/test_mock_llm.py
@@ -0,0 +1,145 @@
import unittest
from typing import List, Optional

from pydantic import BaseModel

from planai.testing.mock_llm import MockLLM, MockLLMResponse


# Test Pydantic models
class Person(BaseModel):
    name: str
    age: int
    hobbies: Optional[List[str]] = None


class Address(BaseModel):
    street: str
    city: str
    country: str


class TestMockLLM(unittest.TestCase):
    def setUp(self):
        # Create some sample mock responses
        self.responses = [
            MockLLMResponse(
                pattern=r"Get person named Alice",
                response=Person(name="Alice", age=30, hobbies=["reading", "hiking"]),
            ),
            MockLLMResponse(
                pattern=r"Get person named Bob",
                response=Person(name="Bob", age=25, hobbies=["gaming"]),
            ),
            MockLLMResponse(
                pattern=r"Get invalid person",
                raise_exception=True,
                exception=ValueError("Invalid person request"),
            ),
            MockLLMResponse(pattern=r"Get none result", response=None),
            MockLLMResponse(
                pattern=r"Get address.*",
                response=Address(
                    street="123 Main St", city="Springfield", country="USA"
                ),
            ),
        ]
        self.mock_llm = MockLLM(responses=self.responses)

    def test_successful_response(self):
        """Test successful response matching."""
        result = self.mock_llm.generate_pydantic(
            prompt_template="Get person named Alice",
            output_schema=Person,
            system="Test system prompt",
        )

        self.assertIsInstance(result, Person)
        self.assertEqual(result.name, "Alice")
        self.assertEqual(result.age, 30)
        self.assertEqual(result.hobbies, ["reading", "hiking"])

    def test_multiple_patterns(self):
        """Test multiple different patterns work correctly."""
        result1 = self.mock_llm.generate_pydantic(
            prompt_template="Get person named Alice", output_schema=Person
        )
        result2 = self.mock_llm.generate_pydantic(
            prompt_template="Get person named Bob", output_schema=Person
        )

        self.assertEqual(result1.name, "Alice")
        self.assertEqual(result2.name, "Bob")

    def test_exception_raising(self):
        """Test that exceptions are raised correctly."""
        with self.assertRaises(ValueError) as context:
            self.mock_llm.generate_pydantic(
                prompt_template="Get invalid person", output_schema=Person
            )

        self.assertEqual(str(context.exception), "Invalid person request")

    def test_none_response(self):
        """Test that None responses are handled correctly."""
        result = self.mock_llm.generate_pydantic(
            prompt_template="Get none result", output_schema=Person
        )

        self.assertIsNone(result)

    def test_no_matching_pattern(self):
        """Test behavior when no pattern matches."""
        with self.assertRaises(ValueError) as context:
            self.mock_llm.generate_pydantic(
                prompt_template="This won't match any pattern", output_schema=Person
            )

        self.assertIn("No matching mock response found", str(context.exception))

    def test_regex_pattern_matching(self):
        """Test that regex patterns work correctly."""
        result = self.mock_llm.generate_pydantic(
            prompt_template="Get address for John Doe", output_schema=Address
        )

        self.assertIsInstance(result, Address)
        self.assertEqual(result.street, "123 Main St")
        self.assertEqual(result.city, "Springfield")

    def test_different_output_schemas(self):
        """Test that different output schemas work correctly."""
        person_result = self.mock_llm.generate_pydantic(
            prompt_template="Get person named Alice", output_schema=Person
        )
        address_result = self.mock_llm.generate_pydantic(
            prompt_template="Get address for someone", output_schema=Address
        )

        self.assertIsInstance(person_result, Person)
        self.assertIsInstance(address_result, Address)

    def test_with_system_prompt(self):
        """Test that system prompts are included in pattern matching."""
        result = self.mock_llm.generate_pydantic(
            prompt_template="Get person named Alice",
            system="This is a test system prompt",
            output_schema=Person,
        )

        self.assertIsInstance(result, Person)
        self.assertEqual(result.name, "Alice")

    def test_json_mode(self):
        """Test MockLLM with JSON mode (non-structured outputs)."""
        mock_llm = MockLLM(responses=self.responses, support_structured_outputs=False)

        result = mock_llm.generate_pydantic(
            prompt_template="Get person named Alice", output_schema=Person
        )

        self.assertIsInstance(result, Person)
        self.assertEqual(result.name, "Alice")


if __name__ == "__main__":
    unittest.main()
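Because of the __main__ guard, the suite can be run directly (assuming planai is importable), e.g. python tests/planai/testing/test_mock_llm.py, or through the standard unittest runner.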
