diff --git a/src/lib.rs b/src/lib.rs index 752b445..f0b9691 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -299,7 +299,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient { model: &str, messages: Vec, ) -> Result { - let provider_str = provider.to_string().to_lowercase(); // force lowercase - TODO - fix the serialization + let provider_str = provider.to_string(); let url = format!("{}/llms/{}/generate", self.base_url, provider_str); let mut request = self.client.post(&url); if let Some(token) = &self.token { @@ -312,23 +312,24 @@ impl InferenceGatewayAPI for InferenceGatewayClient { }; let response = request.json(&request_payload).send().await?; + match response.status() { StatusCode::OK => Ok(response.json().await?), - StatusCode::UNAUTHORIZED => { - let error: ErrorResponse = response.json().await?; - Err(GatewayError::Unauthorized(error.error)) - } StatusCode::BAD_REQUEST => { let error: ErrorResponse = response.json().await?; Err(GatewayError::BadRequest(error.error)) } + StatusCode::UNAUTHORIZED => { + let error: ErrorResponse = response.json().await?; + Err(GatewayError::Unauthorized(error.error)) + } StatusCode::INTERNAL_SERVER_ERROR => { let error: ErrorResponse = response.json().await?; Err(GatewayError::InternalError(error.error)) } - _ => Err(GatewayError::Other(Box::new(std::io::Error::new( + status => Err(GatewayError::Other(Box::new(std::io::Error::new( std::io::ErrorKind::Other, - format!("Unexpected status code: {}", response.status()), + format!("Unexpected status code: {}", status), )))), } } @@ -349,58 +350,6 @@ mod tests { use super::*; use mockito::{Matcher, Server}; - #[tokio::test] - async fn test_gateway_errors() -> Result<(), GatewayError> { - let mut server: mockito::ServerGuard = Server::new_async().await; - - // Test unauthorized error - let unauthorized_mock = server - .mock("GET", "/llms") - .with_status(401) - .with_header("content-type", "application/json") - .with_body(r#"{"error":"Invalid token"}"#) - .create(); - - let client = InferenceGatewayClient::new(&server.url()); - match client.list_models().await { - Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"), - _ => panic!("Expected Unauthorized error"), - } - unauthorized_mock.assert(); - - // Test bad request error - let bad_request_mock = server - .mock("GET", "/llms") - .with_status(400) - .with_header("content-type", "application/json") - .with_body(r#"{"error":"Invalid provider"}"#) - .create(); - - match client.list_models().await { - Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"), - _ => panic!("Expected BadRequest error"), - } - bad_request_mock.assert(); - - // Test internal server error - let internal_error_mock = server - .mock("GET", "/llms") - .with_status(500) - .with_header("content-type", "application/json") - .with_body(r#"{"error":"Internal server error occurred"}"#) - .create(); - - match client.list_models().await { - Err(GatewayError::InternalError(msg)) => { - assert_eq!(msg, "Internal server error occurred") - } - _ => panic!("Expected InternalError error"), - } - internal_error_mock.assert(); - - Ok(()) - } - #[test] fn test_provider_serialization() { let providers = vec![ @@ -437,6 +386,18 @@ mod tests { } } + #[test] + fn test_provider_case_serialization() { + // Test that Provider::Groq serializes to lowercase + let provider = Provider::Groq; + let json = serde_json::to_string(&provider).unwrap(); + assert_eq!(json, r#""groq""#); + + // Test that uppercase fails to deserialize + let result: Result = serde_json::from_str(r#""Groq""#); + assert!(result.is_err()); + } + #[test] fn test_provider_display() { let providers = vec![ @@ -492,11 +453,16 @@ mod tests { #[tokio::test] async fn test_unauthorized_error() -> Result<(), GatewayError> { let mut server = Server::new_async().await; + + let raw_json_response = r#"{ + "error": "Invalid token" + }"#; + let mock = server .mock("GET", "/llms") .with_status(401) .with_header("content-type", "application/json") - .with_body(r#"{"error":"Invalid token"}"#) + .with_body(raw_json_response) .create(); let client = InferenceGatewayClient::new(&server.url()); @@ -514,11 +480,21 @@ mod tests { #[tokio::test] async fn test_list_models() -> Result<(), GatewayError> { let mut server = Server::new_async().await; + + let raw_response_json = r#"[ + { + "provider": "ollama", + "models": [ + {"name": "llama2"} + ] + } + ]"#; + let mock = server .mock("GET", "/llms") .with_status(200) .with_header("content-type", "application/json") - .with_body(r#"[{"provider":"ollama","models":[{"name":"llama2"}]}]"#) + .with_body(raw_response_json) .create(); let client = InferenceGatewayClient::new(&server.url()); @@ -534,11 +510,19 @@ mod tests { #[tokio::test] async fn test_list_models_by_provider() -> Result<(), GatewayError> { let mut server = Server::new_async().await; + + let raw_json_response = r#"{ + "provider":"ollama", + "models": [{ + "name": "llama2" + }] + }"#; + let mock = server .mock("GET", "/llms/ollama") .with_status(200) .with_header("content-type", "application/json") - .with_body(r#"{"provider":"ollama","models":[{"name":"llama2"}]}"#) + .with_body(raw_json_response) .create(); let client = InferenceGatewayClient::new(&server.url()); @@ -554,11 +538,21 @@ mod tests { #[tokio::test] async fn test_generate_content() -> Result<(), GatewayError> { let mut server = Server::new_async().await; + + let raw_json_response = r#"{ + "provider":"ollama", + "response":{ + "role":"assistant", + "model":"llama2", + "content":"Hellloooo" + } + }"#; + let mock = server .mock("POST", "/llms/ollama/generate") .with_status(200) .with_header("content-type", "application/json") - .with_body(r#"{"provider":"ollama","response":{"role":"assistant","model":"llama2","content":"Hellloooo"}}"#) + .with_body(raw_json_response) .create(); let client = InferenceGatewayClient::new(&server.url()); @@ -579,6 +573,144 @@ mod tests { Ok(()) } + #[tokio::test] + async fn test_generate_content_serialization() -> Result<(), GatewayError> { + let mut server = Server::new_async().await; + + // Raw JSON response from API for debugging + let raw_json = r#"{ + "provider": "groq", + "response": { + "role": "assistant", + "model": "mixtral-8x7b", + "content": "Hello" + } + }"#; + + // Create mock with exact JSON structure + let mock = server + .mock("POST", "/llms/groq/generate") + .with_status(200) + .with_header("content-type", "application/json") + .with_body(raw_json) + .create(); + + let client = InferenceGatewayClient::new(&server.url()); + + // Test direct JSON deserialization first + let direct_parse: Result = serde_json::from_str(raw_json); + assert!( + direct_parse.is_ok(), + "Direct JSON parse failed: {:?}", + direct_parse.err() + ); + + // Test through client + let messages = vec![Message { + role: MessageRole::User, + content: "Hello".to_string(), + }]; + + let response = client + .generate_content(Provider::Groq, "mixtral-8x7b", messages) + .await?; + + // Verify structure matches + assert_eq!(response.provider, Provider::Groq); + assert_eq!(response.response.role, MessageRole::Assistant); + assert_eq!(response.response.model, "mixtral-8x7b"); + assert_eq!(response.response.content, "Hello"); + + mock.assert(); + Ok(()) + } + + #[tokio::test] + async fn test_generate_content_error_response() -> Result<(), GatewayError> { + let mut server = Server::new_async().await; + + let raw_json_response = r#"{ + "error":"Invalid request" + }"#; + + let mock = server + .mock("POST", "/llms/groq/generate") + .with_status(400) + .with_header("content-type", "application/json") + .with_body(raw_json_response) + .create(); + + let client = InferenceGatewayClient::new(&server.url()); + let messages = vec![Message { + role: MessageRole::User, + content: "Hello".to_string(), + }]; + let error = client + .generate_content(Provider::Groq, "mixtral-8x7b", messages) + .await + .unwrap_err(); + + assert!(matches!(error, GatewayError::BadRequest(_))); + if let GatewayError::BadRequest(msg) = error { + assert_eq!(msg, "Invalid request"); + } + mock.assert(); + + Ok(()) + } + + #[tokio::test] + async fn test_gateway_errors() -> Result<(), GatewayError> { + let mut server: mockito::ServerGuard = Server::new_async().await; + + // Test unauthorized error + let unauthorized_mock = server + .mock("GET", "/llms") + .with_status(401) + .with_header("content-type", "application/json") + .with_body(r#"{"error":"Invalid token"}"#) + .create(); + + let client = InferenceGatewayClient::new(&server.url()); + match client.list_models().await { + Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"), + _ => panic!("Expected Unauthorized error"), + } + unauthorized_mock.assert(); + + // Test bad request error + let bad_request_mock = server + .mock("GET", "/llms") + .with_status(400) + .with_header("content-type", "application/json") + .with_body(r#"{"error":"Invalid provider"}"#) + .create(); + + match client.list_models().await { + Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"), + _ => panic!("Expected BadRequest error"), + } + bad_request_mock.assert(); + + // Test internal server error + let internal_error_mock = server + .mock("GET", "/llms") + .with_status(500) + .with_header("content-type", "application/json") + .with_body(r#"{"error":"Internal server error occurred"}"#) + .create(); + + match client.list_models().await { + Err(GatewayError::InternalError(msg)) => { + assert_eq!(msg, "Internal server error occurred") + } + _ => panic!("Expected InternalError error"), + } + internal_error_mock.assert(); + + Ok(()) + } + #[tokio::test] async fn test_health_check() -> Result<(), GatewayError> { let mut server = Server::new_async().await;