feat(providers): Groq now uses LiteLLM openai-compat (#1303)
Groq has never supported raw (non-chat) completions anyway, which makes it straightforward to switch the provider to LiteLLM's OpenAI-compatible path. The full test suite passes.

I also updated all the openai-compat providers so they accept API keys passed via request headers (`provider_data`).
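
Since the openai-compat providers now read API keys from request headers, a client can supply the key per request instead of configuring it server-side. Below is a minimal sketch using `httpx`; the `X-LlamaStack-Provider-Data` header and the `groq_api_key` field match the Groq provider-data validator in this change, while the endpoint path, port, and payload shape are assumptions that may vary across Llama Stack versions.

```python
# Sketch only: per-request provider API key via provider_data headers.
# The header name and "groq_api_key" field come from the Groq provider-data
# validator; the route and payload below are assumptions, not part of this commit.
import json

import httpx

headers = {
    # provider_data is a JSON-encoded object; the Groq validator expects "groq_api_key".
    "X-LlamaStack-Provider-Data": json.dumps({"groq_api_key": "<your-groq-api-key>"}),
}

payload = {
    "model_id": "groq/llama-3.3-70b-versatile",
    "messages": [{"role": "user", "content": "Say hello in one sentence."}],
    "stream": False,
}

# Assumed route and port; adjust to match your running Llama Stack server.
resp = httpx.post(
    "http://localhost:8321/v1/inference/chat-completion",
    headers=headers,
    json=payload,
    timeout=60,
)
resp.raise_for_status()
print(resp.json())
```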

## Test Plan

```bash
LLAMA_STACK_CONFIG=groq \
   pytest -s -v tests/client-sdk/inference/test_text_inference.py \
   --inference-model=groq/llama-3.3-70b-versatile --vision-inference-model=""
```

Also tested the openai, anthropic, and gemini providers; no regressions.
ashwinb authored Feb 27, 2025
1 parent 564f0e5 commit 928a39d
Showing 23 changed files with 165 additions and 1,004 deletions.
1 change: 1 addition & 0 deletions distributions/dependencies.json
@@ -146,6 +146,7 @@
"fastapi",
"fire",
"fireworks-ai",
"groq",
"httpx",
"litellm",
"matplotlib",
10 changes: 5 additions & 5 deletions docs/source/distributions/self_hosted_distro/groq.md
@@ -37,11 +37,11 @@ The following environment variables can be configured:

The following models are available by default:

- `meta-llama/Llama-3.1-8B-Instruct (llama3-8b-8192)`
- `meta-llama/Llama-3.1-8B-Instruct (llama-3.1-8b-instant)`
- `meta-llama/Llama-3-70B-Instruct (llama3-70b-8192)`
- `meta-llama/Llama-3.3-70B-Instruct (llama-3.3-70b-versatile)`
- `meta-llama/Llama-3.2-3B-Instruct (llama-3.2-3b-preview)`
- `meta-llama/Llama-3.1-8B-Instruct (groq/llama3-8b-8192)`
- `meta-llama/Llama-3.1-8B-Instruct (groq/llama-3.1-8b-instant)`
- `meta-llama/Llama-3-70B-Instruct (groq/llama3-70b-8192)`
- `meta-llama/Llama-3.3-70B-Instruct (groq/llama-3.3-70b-versatile)`
- `meta-llama/Llama-3.2-3B-Instruct (groq/llama-3.2-3b-preview)`


### Prerequisite: API Keys
23 changes: 13 additions & 10 deletions llama_stack/providers/registry/inference.py
@@ -157,16 +157,6 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["groq"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.GroqProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
@@ -214,6 +204,7 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
),
),
remote_provider_spec(
@@ -223,6 +214,7 @@
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
),
),
remote_provider_spec(
@@ -232,6 +224,17 @@
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["groq"],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
),
),
remote_provider_spec(
11 changes: 8 additions & 3 deletions llama_stack/providers/remote/inference/anthropic/anthropic.py
@@ -12,11 +12,16 @@

class AnthropicInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: AnthropicConfig) -> None:
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="anthropic_api_key",
)
self.config = config

async def initialize(self) -> None:
pass
await super().initialize()

async def shutdown(self) -> None:
pass
await super().shutdown()
7 changes: 7 additions & 0 deletions llama_stack/providers/remote/inference/anthropic/config.py
@@ -11,6 +11,13 @@
from llama_stack.schema_utils import json_schema_type


class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: Optional[str] = Field(
default=None,
description="API key for Anthropic models",
)


@json_schema_type
class AnthropicConfig(BaseModel):
api_key: Optional[str] = Field(
7 changes: 7 additions & 0 deletions llama_stack/providers/remote/inference/gemini/config.py
@@ -11,6 +11,13 @@
from llama_stack.schema_utils import json_schema_type


class GeminiProviderDataValidator(BaseModel):
gemini_api_key: Optional[str] = Field(
default=None,
description="API key for Gemini models",
)


@json_schema_type
class GeminiConfig(BaseModel):
api_key: Optional[str] = Field(
11 changes: 8 additions & 3 deletions llama_stack/providers/remote/inference/gemini/gemini.py
@@ -12,11 +12,16 @@

class GeminiInferenceAdapter(LiteLLMOpenAIMixin):
def __init__(self, config: GeminiConfig) -> None:
LiteLLMOpenAIMixin.__init__(self, MODEL_ENTRIES)
LiteLLMOpenAIMixin.__init__(
self,
MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="gemini_api_key",
)
self.config = config

async def initialize(self) -> None:
pass
await super().initialize()

async def shutdown(self) -> None:
pass
await super().shutdown()
9 changes: 0 additions & 9 deletions llama_stack/providers/remote/inference/groq/__init__.py
@@ -4,23 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pydantic import BaseModel

from llama_stack.apis.inference import Inference

from .config import GroqConfig


class GroqProviderDataValidator(BaseModel):
groq_api_key: str


async def get_adapter_impl(config: GroqConfig, _deps) -> Inference:
# import dynamically so the import is used only when it is needed
from .groq import GroqInferenceAdapter

if not isinstance(config, GroqConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")

adapter = GroqInferenceAdapter(config)
return adapter
11 changes: 9 additions & 2 deletions llama_stack/providers/remote/inference/groq/config.py
@@ -11,6 +11,13 @@
from llama_stack.schema_utils import json_schema_type


class GroqProviderDataValidator(BaseModel):
groq_api_key: Optional[str] = Field(
default=None,
description="API key for Groq models",
)


@json_schema_type
class GroqConfig(BaseModel):
api_key: Optional[str] = Field(
@@ -25,8 +32,8 @@ class GroqConfig(BaseModel):
)

@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
return {
"url": "https://api.groq.com",
"api_key": "${env.GROQ_API_KEY}",
"api_key": api_key,
}
130 changes: 13 additions & 117 deletions llama_stack/providers/remote/inference/groq/groq.py
@@ -4,130 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import warnings
from typing import AsyncIterator, List, Optional, Union

import groq
from groq import Groq

from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionResponse,
CompletionResponseStreamChunk,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
InterleavedContent,
InterleavedContentItem,
LogProbConfig,
Message,
ResponseFormat,
TextTruncation,
ToolChoice,
ToolConfig,
)
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.models.llama.datatypes import SamplingParams, ToolDefinition, ToolPromptFormat
from llama_stack.providers.remote.inference.groq.config import GroqConfig
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from .groq_utils import (
convert_chat_completion_request,
convert_chat_completion_response,
convert_chat_completion_response_stream,
)
from .models import _MODEL_ENTRIES
from .models import MODEL_ENTRIES


class GroqInferenceAdapter(Inference, ModelRegistryHelper, NeedsRequestProviderData):
class GroqInferenceAdapter(LiteLLMOpenAIMixin):
_config: GroqConfig

def __init__(self, config: GroqConfig):
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
self._config = config

def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
# Groq doesn't support non-chat completion as of time of writing
raise NotImplementedError()

async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
model_id = self.get_provider_model_id(model_id)
if model_id == "llama-3.2-3b-preview":
warnings.warn(
"Groq only contains a preview version for llama-3.2-3b-instruct. "
"Preview models aren't recommended for production use. "
"They can be discontinued on short notice."
"More details: https://console.groq.com/docs/models"
)

request = convert_chat_completion_request(
request=ChatCompletionRequest(
model=model_id,
messages=messages,
sampling_params=sampling_params,
response_format=response_format,
tools=tools,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
LiteLLMOpenAIMixin.__init__(
self,
model_entries=MODEL_ENTRIES,
api_key_from_config=config.api_key,
provider_data_api_key_field="groq_api_key",
)
self.config = config

try:
response = self._get_client().chat.completions.create(**request)
except groq.BadRequestError as e:
if e.body.get("error", {}).get("code") == "tool_use_failed":
# For smaller models, Groq may fail to call a tool even when the request is well formed
raise ValueError("Groq failed to call a tool", e.body.get("error", {})) from e
else:
raise e

if stream:
return convert_chat_completion_response_stream(response)
else:
return convert_chat_completion_response(response)

async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError()
async def initialize(self):
await super().initialize()

def _get_client(self) -> Groq:
if self._config.api_key is not None:
return Groq(api_key=self._config.api_key)
else:
provider_data = self.get_request_provider_data()
if provider_data is None or not provider_data.groq_api_key:
raise ValueError(
'Pass Groq API Key in the header X-LlamaStack-Provider-Data as { "groq_api_key": "<your api key>" }'
)
return Groq(api_key=provider_data.groq_api_key)
async def shutdown(self):
await super().shutdown()