Skip to content

Commit

Permalink
test: Remove unnecessary lowercase conversion for provider string and…
Browse files Browse the repository at this point in the history
… enhance error handling in API responses

Signed-off-by: Eden Reich <[email protected]>
  • Loading branch information
edenreich committed Jan 29, 2025
1 parent 41e9d94 commit cdeb27c
Showing 1 changed file with 195 additions and 63 deletions.
258 changes: 195 additions & 63 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,7 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
model: &str,
messages: Vec<Message>,
) -> Result<GenerateResponse, GatewayError> {
let provider_str = provider.to_string().to_lowercase(); // force lowercase - TODO - fix the serialization
let provider_str = provider.to_string();
let url = format!("{}/llms/{}/generate", self.base_url, provider_str);
let mut request = self.client.post(&url);
if let Some(token) = &self.token {
Expand All @@ -312,23 +312,24 @@ impl InferenceGatewayAPI for InferenceGatewayClient {
};

let response = request.json(&request_payload).send().await?;

match response.status() {
StatusCode::OK => Ok(response.json().await?),
StatusCode::UNAUTHORIZED => {
let error: ErrorResponse = response.json().await?;
Err(GatewayError::Unauthorized(error.error))
}
StatusCode::BAD_REQUEST => {
let error: ErrorResponse = response.json().await?;
Err(GatewayError::BadRequest(error.error))
}
StatusCode::UNAUTHORIZED => {
let error: ErrorResponse = response.json().await?;
Err(GatewayError::Unauthorized(error.error))
}
StatusCode::INTERNAL_SERVER_ERROR => {
let error: ErrorResponse = response.json().await?;
Err(GatewayError::InternalError(error.error))
}
_ => Err(GatewayError::Other(Box::new(std::io::Error::new(
status => Err(GatewayError::Other(Box::new(std::io::Error::new(
std::io::ErrorKind::Other,
format!("Unexpected status code: {}", response.status()),
format!("Unexpected status code: {}", status),
)))),
}
}
Expand All @@ -349,58 +350,6 @@ mod tests {
use super::*;
use mockito::{Matcher, Server};

#[tokio::test]
async fn test_gateway_errors() -> Result<(), GatewayError> {
let mut server: mockito::ServerGuard = Server::new_async().await;

// Test unauthorized error
let unauthorized_mock = server
.mock("GET", "/llms")
.with_status(401)
.with_header("content-type", "application/json")
.with_body(r#"{"error":"Invalid token"}"#)
.create();

let client = InferenceGatewayClient::new(&server.url());
match client.list_models().await {
Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
_ => panic!("Expected Unauthorized error"),
}
unauthorized_mock.assert();

// Test bad request error
let bad_request_mock = server
.mock("GET", "/llms")
.with_status(400)
.with_header("content-type", "application/json")
.with_body(r#"{"error":"Invalid provider"}"#)
.create();

match client.list_models().await {
Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
_ => panic!("Expected BadRequest error"),
}
bad_request_mock.assert();

// Test internal server error
let internal_error_mock = server
.mock("GET", "/llms")
.with_status(500)
.with_header("content-type", "application/json")
.with_body(r#"{"error":"Internal server error occurred"}"#)
.create();

match client.list_models().await {
Err(GatewayError::InternalError(msg)) => {
assert_eq!(msg, "Internal server error occurred")
}
_ => panic!("Expected InternalError error"),
}
internal_error_mock.assert();

Ok(())
}

#[test]
fn test_provider_serialization() {
let providers = vec![
Expand Down Expand Up @@ -437,6 +386,18 @@ mod tests {
}
}

#[test]
fn test_provider_case_serialization() {
// Test that Provider::Groq serializes to lowercase
let provider = Provider::Groq;
let json = serde_json::to_string(&provider).unwrap();
assert_eq!(json, r#""groq""#);

// Test that uppercase fails to deserialize
let result: Result<Provider, _> = serde_json::from_str(r#""Groq""#);
assert!(result.is_err());
}

#[test]
fn test_provider_display() {
let providers = vec![
Expand Down Expand Up @@ -492,11 +453,16 @@ mod tests {
#[tokio::test]
async fn test_unauthorized_error() -> Result<(), GatewayError> {
let mut server = Server::new_async().await;

let raw_json_response = r#"{
"error": "Invalid token"
}"#;

let mock = server
.mock("GET", "/llms")
.with_status(401)
.with_header("content-type", "application/json")
.with_body(r#"{"error":"Invalid token"}"#)
.with_body(raw_json_response)
.create();

let client = InferenceGatewayClient::new(&server.url());
Expand All @@ -514,11 +480,21 @@ mod tests {
#[tokio::test]
async fn test_list_models() -> Result<(), GatewayError> {
let mut server = Server::new_async().await;

let raw_response_json = r#"[
{
"provider": "ollama",
"models": [
{"name": "llama2"}
]
}
]"#;

let mock = server
.mock("GET", "/llms")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(r#"[{"provider":"ollama","models":[{"name":"llama2"}]}]"#)
.with_body(raw_response_json)
.create();

let client = InferenceGatewayClient::new(&server.url());
Expand All @@ -534,11 +510,19 @@ mod tests {
#[tokio::test]
async fn test_list_models_by_provider() -> Result<(), GatewayError> {
let mut server = Server::new_async().await;

let raw_json_response = r#"{
"provider":"ollama",
"models": [{
"name": "llama2"
}]
}"#;

let mock = server
.mock("GET", "/llms/ollama")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(r#"{"provider":"ollama","models":[{"name":"llama2"}]}"#)
.with_body(raw_json_response)
.create();

let client = InferenceGatewayClient::new(&server.url());
Expand All @@ -554,11 +538,21 @@ mod tests {
#[tokio::test]
async fn test_generate_content() -> Result<(), GatewayError> {
let mut server = Server::new_async().await;

let raw_json_response = r#"{
"provider":"ollama",
"response":{
"role":"assistant",
"model":"llama2",
"content":"Hellloooo"
}
}"#;

let mock = server
.mock("POST", "/llms/ollama/generate")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(r#"{"provider":"ollama","response":{"role":"assistant","model":"llama2","content":"Hellloooo"}}"#)
.with_body(raw_json_response)
.create();

let client = InferenceGatewayClient::new(&server.url());
Expand All @@ -579,6 +573,144 @@ mod tests {
Ok(())
}

#[tokio::test]
async fn test_generate_content_serialization() -> Result<(), GatewayError> {
let mut server = Server::new_async().await;

// Raw JSON response from API for debugging
let raw_json = r#"{
"provider": "groq",
"response": {
"role": "assistant",
"model": "mixtral-8x7b",
"content": "Hello"
}
}"#;

// Create mock with exact JSON structure
let mock = server
.mock("POST", "/llms/groq/generate")
.with_status(200)
.with_header("content-type", "application/json")
.with_body(raw_json)
.create();

let client = InferenceGatewayClient::new(&server.url());

// Test direct JSON deserialization first
let direct_parse: Result<GenerateResponse, _> = serde_json::from_str(raw_json);
assert!(
direct_parse.is_ok(),
"Direct JSON parse failed: {:?}",
direct_parse.err()
);

// Test through client
let messages = vec![Message {
role: MessageRole::User,
content: "Hello".to_string(),
}];

let response = client
.generate_content(Provider::Groq, "mixtral-8x7b", messages)
.await?;

// Verify structure matches
assert_eq!(response.provider, Provider::Groq);
assert_eq!(response.response.role, MessageRole::Assistant);
assert_eq!(response.response.model, "mixtral-8x7b");
assert_eq!(response.response.content, "Hello");

mock.assert();
Ok(())
}

#[tokio::test]
async fn test_generate_content_error_response() -> Result<(), GatewayError> {
let mut server = Server::new_async().await;

let raw_json_response = r#"{
"error":"Invalid request"
}"#;

let mock = server
.mock("POST", "/llms/groq/generate")
.with_status(400)
.with_header("content-type", "application/json")
.with_body(raw_json_response)
.create();

let client = InferenceGatewayClient::new(&server.url());
let messages = vec![Message {
role: MessageRole::User,
content: "Hello".to_string(),
}];
let error = client
.generate_content(Provider::Groq, "mixtral-8x7b", messages)
.await
.unwrap_err();

assert!(matches!(error, GatewayError::BadRequest(_)));
if let GatewayError::BadRequest(msg) = error {
assert_eq!(msg, "Invalid request");
}
mock.assert();

Ok(())
}

#[tokio::test]
async fn test_gateway_errors() -> Result<(), GatewayError> {
let mut server: mockito::ServerGuard = Server::new_async().await;

// Test unauthorized error
let unauthorized_mock = server
.mock("GET", "/llms")
.with_status(401)
.with_header("content-type", "application/json")
.with_body(r#"{"error":"Invalid token"}"#)
.create();

let client = InferenceGatewayClient::new(&server.url());
match client.list_models().await {
Err(GatewayError::Unauthorized(msg)) => assert_eq!(msg, "Invalid token"),
_ => panic!("Expected Unauthorized error"),
}
unauthorized_mock.assert();

// Test bad request error
let bad_request_mock = server
.mock("GET", "/llms")
.with_status(400)
.with_header("content-type", "application/json")
.with_body(r#"{"error":"Invalid provider"}"#)
.create();

match client.list_models().await {
Err(GatewayError::BadRequest(msg)) => assert_eq!(msg, "Invalid provider"),
_ => panic!("Expected BadRequest error"),
}
bad_request_mock.assert();

// Test internal server error
let internal_error_mock = server
.mock("GET", "/llms")
.with_status(500)
.with_header("content-type", "application/json")
.with_body(r#"{"error":"Internal server error occurred"}"#)
.create();

match client.list_models().await {
Err(GatewayError::InternalError(msg)) => {
assert_eq!(msg, "Internal server error occurred")
}
_ => panic!("Expected InternalError error"),
}
internal_error_mock.assert();

Ok(())
}

#[tokio::test]
async fn test_health_check() -> Result<(), GatewayError> {
let mut server = Server::new_async().await;
Expand Down

0 comments on commit cdeb27c

Please sign in to comment.