chore(lint): add linting task using Black for code formatting
Signed-off-by: Eden Reich <[email protected]>
edenreich committed Jan 21, 2025
1 parent 0166d9e commit 642ed74
Showing 3 changed files with 17 additions and 26 deletions.
Taskfile.yml (5 changes: 5 additions, 0 deletions)
@@ -7,6 +7,11 @@ tasks:
     cmds:
       - curl -o openapi.yaml https://raw.githubusercontent.com/inference-gateway/inference-gateway/refs/heads/main/openapi.yaml

+  lint:
+    desc: Lint the code
+    cmds:
+      - black inference_gateway/ tests/
+
   test:
     desc: Run tests
     cmds:
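
With this task in place, formatting can be run through the task runner. A minimal sketch of the intended workflow, assuming the go-task CLI and Black are installed in the development environment (the check-only invocation is an assumption, not part of this commit):

    # Format the package and test code in place via the new task
    task lint

    # Verify formatting without rewriting files (e.g. in CI)
    black --check inference_gateway/ tests/
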
inference_gateway/client.py (24 changes: 7 additions, 17 deletions)
@@ -6,6 +6,7 @@

 class Provider(str, Enum):
     """Supported LLM providers"""
+
     OLLAMA = "ollama"
     GROQ = "groq"
     OPENAI = "openai"
@@ -16,6 +17,7 @@ class Provider(str, Enum):

 class Role(str, Enum):
     """Message role types"""
+
     SYSTEM = "system"
     USER = "user"
     ASSISTANT = "assistant"
@@ -28,10 +30,7 @@ class Message:

     def to_dict(self) -> Dict[str, str]:
         """Convert message to dictionary format with string values"""
-        return {
-            "role": self.role.value,
-            "content": self.content
-        }
+        return {"role": self.role.value, "content": self.content}


 class Model:
@@ -57,7 +56,7 @@ class InferenceGatewayClient:

     def __init__(self, base_url: str, token: Optional[str] = None):
         """Initialize the client with base URL and optional auth token"""
-        self.base_url = base_url.rstrip('/')
+        self.base_url = base_url.rstrip("/")
         self.session = requests.Session()
         if token:
             self.session.headers.update({"Authorization": f"Bearer {token}"})
@@ -68,20 +67,11 @@ def list_models(self) -> List[ProviderModels]:
         response.raise_for_status()
         return response.json()

-    def generate_content(
-        self,
-        provider: Provider,
-        model: str,
-        messages: List[Message]
-    ) -> Dict:
-        payload = {
-            "model": model,
-            "messages": [msg.to_dict() for msg in messages]
-        }
+    def generate_content(self, provider: Provider, model: str, messages: List[Message]) -> Dict:
+        payload = {"model": model, "messages": [msg.to_dict() for msg in messages]}

         response = self.session.post(
-            f"{self.base_url}/llms/{provider.value}/generate",
-            json=payload
+            f"{self.base_url}/llms/{provider.value}/generate", json=payload
         )
         response.raise_for_status()
         return response.json()
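
The reformatting above does not change the client's public surface: Provider, Role, Message, InferenceGatewayClient.list_models and generate_content behave as before. A minimal usage sketch based only on the names visible in this diff; the base URL and token are placeholders:

    from inference_gateway.client import InferenceGatewayClient, Message, Provider, Role

    # Placeholder URL and token; point these at a running Inference Gateway instance
    client = InferenceGatewayClient("http://localhost:8080", token="my-token")

    messages = [
        Message(Role.SYSTEM, "You are a helpful assistant"),
        Message(Role.USER, "Hello!"),
    ]

    # Sends a POST to {base_url}/llms/openai/generate with the serialized messages
    response = client.generate_content(Provider.OPENAI, "gpt-4", messages)
    print(response)
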
tests/test_client.py (14 changes: 5 additions, 9 deletions)
@@ -24,8 +24,7 @@ def test_client_initialization():
     assert client.base_url == "http://test-api"
     assert "Authorization" not in client.session.headers

-    client_with_token = InferenceGatewayClient(
-        "http://test-api", token="test-token")
+    client_with_token = InferenceGatewayClient("http://test-api", token="test-token")
     assert "Authorization" in client_with_token.session.headers
     assert client_with_token.session.headers["Authorization"] == "Bearer test-token"

@@ -43,10 +42,7 @@ def test_list_models(mock_get, client, mock_response):
 @patch("requests.Session.post")
 def test_generate_content(mock_post, client, mock_response):
     """Test content generation"""
-    messages = [
-        Message(Role.SYSTEM, "You are a helpful assistant"),
-        Message(Role.USER, "Hello!")
-    ]
+    messages = [Message(Role.SYSTEM, "You are a helpful assistant"), Message(Role.USER, "Hello!")]

     mock_post.return_value = mock_response
     response = client.generate_content(Provider.OPENAI, "gpt-4", messages)
@@ -57,9 +53,9 @@ def test_generate_content(mock_post, client, mock_response):
             "model": "gpt-4",
             "messages": [
                 {"role": "system", "content": "You are a helpful assistant"},
-                {"role": "user", "content": "Hello!"}
-            ]
-        }
+                {"role": "user", "content": "Hello!"},
+            ],
+        },
     )
     assert response == {"response": "test"}

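
These tests rely on client and mock_response fixtures defined outside the hunks shown here. A hypothetical sketch of what mock_response could look like, inferred from the {"response": "test"} assertion above; the actual fixture in the repository may differ:

    import pytest
    from unittest.mock import Mock

    @pytest.fixture
    def mock_response():
        # Stand-in for a requests.Response; json() matches the asserted payload
        response = Mock()
        response.json.return_value = {"response": "test"}
        response.raise_for_status.return_value = None
        return response
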

