diff --git a/src/planai/llm_interface.py b/src/planai/llm_interface.py
index 9d678eb..5f45b01 100644
--- a/src/planai/llm_interface.py
+++ b/src/planai/llm_interface.py
@@ -116,32 +116,6 @@ def chat(
         )
         return response.strip() if isinstance(response, str) else response
 
-    def _cached_generate(self, prompt: str, system: str = "", format: str = "") -> str:
-        # Hash the prompt to use as the cache key
-        prompt_hash = self._generate_hash(
-            self.model_name + "\n" + system + "\n" + prompt
-        )
-
-        # Check if prompt response is in cache
-        response = self.disk_cache.get(prompt_hash)
-
-        if response is None:
-            # If not in cache, make request to client
-            response = self.client.generate(
-                model=self.model_name, prompt=prompt, system=system, format=format
-            )
-
-            # Cache the response with hashed prompt as key
-            self.disk_cache.set(prompt_hash, response)
-
-        return response
-
-    def generate(self, prompt: str, system: str = "") -> str:
-        self.logger.info("Generating text with prompt: %s...", prompt[:850])
-        response = self._cached_generate(prompt=prompt, system=system)
-        self.logger.info("Generated text: %s...", response["response"][:850])
-        return response["response"].strip()
-
     def _strip_text_from_json_response(self, response: str) -> str:
         pattern = r"^[^{\[]*([{\[].*[}\]])[^}\]]*$"
         match = re.search(pattern, response, re.DOTALL)
diff --git a/tests/planai/test_llm_interface.py b/tests/planai/test_llm_interface.py
index 3acf210..9f6318c 100644
--- a/tests/planai/test_llm_interface.py
+++ b/tests/planai/test_llm_interface.py
@@ -38,52 +38,6 @@ def setUp(self):
         self.response_content = "Paris"
         self.response_data = {"message": {"content": self.response_content}}
 
-    def test_generate_with_cache_miss(self):
-        self.mock_client.generate.return_value = {"response": self.response_content}
-
-        # Call generate
-        response = self.llm_interface.generate(prompt=self.prompt, system=self.system)
-
-        self.mock_client.generate.assert_called_once_with(
-            model=self.llm_interface.model_name,
-            prompt=self.prompt,
-            system=self.system,
-            format="",
-        )
-        # Since we changed to use self.response_content directly
-        self.assertEqual(response, self.response_content)
-
-    def test_generate_with_cache_hit(self):
-        prompt_hash = self.llm_interface._generate_hash(
-            self.llm_interface.model_name + "\n" + self.system + "\n" + self.prompt
-        )
-        self.llm_interface.disk_cache.set(
-            prompt_hash, {"response": self.response_content}
-        )
-
-        # Call generate
-        response = self.llm_interface.generate(prompt=self.prompt, system=self.system)
-
-        # Since it's a cache hit, no chat call should happen
-        self.mock_client.generate.assert_not_called()
-
-        # Confirming expected parsing
-        self.assertEqual(response, self.response_content)
-
-    def test_generate_invalid_json_response(self):
-        # Simulate invalid JSON response
-        invalid_json_response = {"response": "Not a JSON {...."}
-        self.mock_client.generate.return_value = invalid_json_response
-
-        with patch("planai.llm_interface.logging.Logger") as mock_logger:
-            self.llm_interface.logger = mock_logger
-            response = self.llm_interface.generate(
-                prompt=self.prompt, system=self.system
-            )
-
-            # Expecting the invalid content since there's no parsing
-            self.assertEqual(response, "Not a JSON {....")
-
     def test_generate_pydantic_valid_response(self):
         output_model = DummyPydanticModel(field1="test", field2=42)
         valid_json_response = '{"field1": "test", "field2": 42}'
@@ -114,17 +68,6 @@ def test_generate_pydantic_invalid_response(self):
         self.assertIsNone(response)  # Expecting None due to parsing error
 
-    def test_cached_generate_caching_mechanism(self):
-        # First call should miss cache and make client call
-        self.mock_client.generate.return_value = self.response_data
-        response = self.llm_interface._cached_generate(self.prompt, self.system)
-        self.assertEqual(response, self.response_data)
-
-        # Second call should hit cache, no additional client call
-        response = self.llm_interface._cached_generate(self.prompt, self.system)
-        self.mock_client.generate.assert_called_once()  # Still called only once
-        self.assertEqual(response, self.response_data)
-
     def test_generate_pydantic_with_retry_logic_and_prompt_check(self):
         # Simulate an invalid JSON response that fails to parse initially
         invalid_content = '{"field1": "test"}'