diff --git a/src/wandbot/chat/chat_model.py b/src/wandbot/chat/chat_model.py index f2eca24..a1c9a8e 100644 --- a/src/wandbot/chat/chat_model.py +++ b/src/wandbot/chat/chat_model.py @@ -68,35 +68,8 @@ def _convert_messages(self, messages: List[Dict[str, str]]) -> List[Dict[str, st if msg["role"] not in self.VALID_ROLES: raise ValueError(f"Invalid role: {msg['role']}") - # Handle provider-specific message formats - if "openai" in self.model_name: - # OpenAI: Convert system to developer role - return [ - { - "role": "developer" if msg["role"] == "system" else msg["role"], - "content": msg["content"] - } - for msg in messages - ] - elif "anthropic" in self.model_name: - # Anthropic: Handle system message separately - system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), None) - messages = [msg for msg in messages if msg["role"] != "system"] - if system_msg: - messages.insert(0, {"role": "system", "content": system_msg}) - return messages - elif "gemini" in self.model_name: - # Gemini: Prepend system message to first user message - system_msg = next((msg["content"] for msg in messages if msg["role"] == "system"), None) - if system_msg: - messages = [msg for msg in messages if msg["role"] != "system"] - for msg in messages: - if msg["role"] == "user": - msg["content"] = f"{system_msg}\n\n{msg['content']}" - break - return messages - else: - return messages + # LiteLLM handles provider-specific message formats + return messages def generate_response( self, diff --git a/tests/test_chat_model.py b/tests/test_chat_model.py index 9efe56e..ce061fb 100644 --- a/tests/test_chat_model.py +++ b/tests/test_chat_model.py @@ -1,6 +1,6 @@ """Tests for the ChatModel class.""" import unittest -from unittest.mock import patch, MagicMock, call +from unittest.mock import patch, MagicMock import litellm from wandbot.chat.chat_model import ChatModel @@ -13,76 +13,6 @@ def setUp(self): {"role": "user", "content": "Hello"} ] - def test_openai_role_conversion(self): - """Test that system role is converted to developer for OpenAI models.""" - with patch('litellm.completion') as mock_completion: - mock_response = MagicMock() - mock_response.choices = [ - MagicMock(message=MagicMock(content="Hi!")) - ] - mock_response.usage = MagicMock( - total_tokens=10, - prompt_tokens=8, - completion_tokens=2 - ) - mock_response.model = "openai/gpt-4" - mock_completion.return_value = mock_response - mock_completion.call_args = MagicMock( - kwargs={ - "model": "openai/gpt-4", - "messages": [ - {"role": "developer", "content": "You are a helpful assistant"}, - {"role": "user", "content": "Hello"} - ], - "max_tokens": 1000, - "temperature": 0.1, - "timeout": None, - "num_retries": 3 - } - ) - - response = self.model.generate_response(self.test_messages) - - # Verify system role was converted to developer - self.assertEqual( - mock_completion.call_args.kwargs["messages"][0]["role"], - "developer" - ) - self.assertEqual( - mock_completion.call_args.kwargs["messages"][1]["role"], - "user" - ) - - def test_error_handling(self): - """Test error handling.""" - with patch('litellm.completion') as mock_completion: - # Test retryable error - mock_completion.side_effect = litellm.exceptions.RateLimitError( - message="Rate limit exceeded", - llm_provider="openai", - model="gpt-4" - ) - response = self.model.generate_response(self.test_messages) - self.assertTrue(response["error"]["retryable"]) - - # Test non-retryable error - mock_completion.side_effect = litellm.exceptions.AuthenticationError( - message="Invalid API key", - llm_provider="openai", - model="gpt-4" - ) - response = self.model.generate_response(self.test_messages) - self.assertFalse(response["error"]["retryable"]) - - # Test server error - mock_completion.side_effect = litellm.exceptions.ServiceUnavailableError( - message="Internal server error", - llm_provider="openai", - model="gpt-4" - ) - response = self.model.generate_response(self.test_messages) - self.assertTrue(response["error"]["retryable"]) - def test_message_format_validation(self): """Test message format validation.""" invalid_messages = [ @@ -108,29 +38,26 @@ def test_message_format_validation(self): self.assertEqual(response["error"]["type"], "ValueError") self.assertFalse(response["error"]["retryable"]) - def test_different_providers(self): - """Test different model providers handle their quirks.""" + def test_message_passing(self): + """Test that messages are passed correctly to LiteLLM.""" test_cases = [ - # OpenAI uses "developer" role + # System message { "model": "openai/gpt-4", - "messages": [{"role": "system", "content": "Be helpful"}], - "expected_role": "developer" + "messages": [{"role": "system", "content": "Be helpful"}] }, - # Anthropic handles system message separately + # User message { "model": "anthropic/claude-3", - "messages": [{"role": "system", "content": "Be helpful"}], - "expected_system": True + "messages": [{"role": "user", "content": "Hi"}] }, - # Gemini prepends system to first user message + # Multiple messages { "model": "gemini/gemini-pro", "messages": [ {"role": "system", "content": "Be helpful"}, {"role": "user", "content": "Hi"} - ], - "expected_prepend": True + ] } ] @@ -145,29 +72,14 @@ def test_different_providers(self): completion_tokens=5 ) mock_completion.return_value = mock_response - mock_completion.call_args = MagicMock( - kwargs={ - "model": case["model"], - "messages": case["messages"], - "max_tokens": 1000, - "temperature": 0.1, - "timeout": None, - "num_retries": 3 - } - ) response = model.generate_response(case["messages"]) - # Verify provider-specific handling - messages = mock_completion.call_args.kwargs["messages"] - if "expected_role" in case: - self.assertEqual(messages[0]["role"], case["expected_role"]) - if "expected_system" in case: - self.assertEqual(messages[0]["role"], "system") - if "expected_prepend" in case: - self.assertTrue( - case["messages"][0]["content"] in messages[0]["content"] - ) + # Verify messages are passed through unchanged + self.assertEqual( + mock_completion.call_args.kwargs["messages"], + case["messages"] + ) def test_retries_and_fallbacks(self): """Test retry and fallback behavior.""" @@ -205,6 +117,7 @@ def test_retries_and_fallbacks(self): call_args = mock_completion.call_args_list[0].kwargs self.assertEqual(call_args["num_retries"], 2) self.assertEqual(call_args["timeout"], 10) + self.assertEqual(call_args["fallbacks"], ["anthropic/claude-3-haiku", "gemini/gemini-2.0-flash"]) def test_context_window_limits(self): """Test handling of context window limits with modern models."""