diff --git a/litellm/llms/vertex_ai/vertex_ai_non_gemini.py b/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
index 418d8813dc1a..744e1eb3177d 100644
--- a/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
+++ b/litellm/llms/vertex_ai/vertex_ai_non_gemini.py
@@ -7,6 +7,7 @@
 
 import litellm
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.bedrock.common_utils import ModelResponseIterator
 from litellm.llms.custom_httpx.http_handler import _DEFAULT_TTL_FOR_HTTPX_CLIENTS
 from litellm.types.llms.vertex_ai import *
 from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@@ -197,6 +198,7 @@ def completion(  # noqa: PLR0915
         client_options = {
             "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
         }
+        fake_stream = False
         if (
             model in litellm.vertex_language_models
             or model in litellm.vertex_vision_models
@@ -220,6 +222,7 @@ def completion(  # noqa: PLR0915
             )
             mode = "text"
             request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n"
+            fake_stream = True
         elif model in litellm.vertex_code_chat_models:  # vertex_code_llm_models
             llm_model = _vertex_llm_model_object or CodeChatModel.from_pretrained(model)
             mode = "chat"
@@ -275,17 +278,22 @@ def completion(  # noqa: PLR0915
             return async_completion(**data)
 
         completion_response = None
+
+        stream = optional_params.pop(
+            "stream", None
+        )  # See note above on handling streaming for vertex ai
         if mode == "chat":
             chat = llm_model.start_chat()
             request_str += "chat = llm_model.start_chat()\n"
-            if "stream" in optional_params and optional_params["stream"] is True:
+            if fake_stream is not True and stream is True:
                 # NOTE: VertexAI does not accept stream=True as a param and raises an error,
                 # we handle this by removing 'stream' from optional params and sending the request
                 # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
                 optional_params.pop(
                     "stream", None
                 )  # vertex ai raises an error when passing stream in optional params
+
                 request_str += (
                     f"chat.send_message_streaming({prompt}, **{optional_params})\n"
                 )
@@ -298,6 +306,7 @@ def completion(  # noqa: PLR0915
                         "request_str": request_str,
                     },
                 )
+
                 model_response = chat.send_message_streaming(prompt, **optional_params)
                 return model_response
 
@@ -314,10 +323,8 @@ def completion(  # noqa: PLR0915
             )
             completion_response = chat.send_message(prompt, **optional_params).text
         elif mode == "text":
-            if "stream" in optional_params and optional_params["stream"] is True:
-                optional_params.pop(
-                    "stream", None
-                )  # See note above on handling streaming for vertex ai
+
+            if fake_stream is not True and stream is True:
                 request_str += (
                     f"llm_model.predict_streaming({prompt}, **{optional_params})\n"
                 )
@@ -384,7 +391,7 @@ def completion(  # noqa: PLR0915
                 and "\nOutput:\n" in completion_response
             ):
                 completion_response = completion_response.split("\nOutput:\n", 1)[1]
-            if "stream" in optional_params and optional_params["stream"] is True:
+            if stream is True:
                 response = TextStreamer(completion_response)
                 return response
         elif mode == "private":
@@ -413,7 +420,7 @@ def completion(  # noqa: PLR0915
                 and "\nOutput:\n" in completion_response
             ):
                 completion_response = completion_response.split("\nOutput:\n", 1)[1]
-            if "stream" in optional_params and optional_params["stream"] is True:
+            if stream is True:
                 response = TextStreamer(completion_response)
                 return response
 
@@ -465,6 +472,9 @@ def completion(  # noqa: PLR0915
             total_tokens=prompt_tokens + completion_tokens,
         )
         setattr(model_response, "usage", usage)
+
+        if fake_stream is True and stream is True:
+            return ModelResponseIterator(model_response)
         return model_response
     except Exception as e:
         if isinstance(e, VertexAIError):
diff --git a/litellm/utils.py b/litellm/utils.py
index f3789fe12940..2ab06ba89ad4 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4224,6 +4224,7 @@ def _get_model_info_helper(  # noqa: PLR0915
         _model_info: Optional[Dict[str, Any]] = None
         key: Optional[str] = None
         provider_config: Optional[BaseLLMModelInfo] = None
+
         if combined_model_name in litellm.model_cost:
             key = combined_model_name
             _model_info = _get_model_info_from_model_cost(key=key)
@@ -4263,7 +4264,10 @@ def _get_model_info_helper(  # noqa: PLR0915
         ):
             _model_info = None
 
-        if custom_llm_provider:
+        if custom_llm_provider and custom_llm_provider in [
+            provider.value for provider in LlmProviders
+        ]:
+            # Check if the provider string exists in LlmProviders enum
             provider_config = ProviderConfigManager.get_provider_model_info(
                 model=model, provider=LlmProviders(custom_llm_provider)
             )
diff --git a/tests/local_testing/test_amazing_vertex_completion.py b/tests/local_testing/test_amazing_vertex_completion.py
index 0dc814961101..e7bca6bc13f5 100644
--- a/tests/local_testing/test_amazing_vertex_completion.py
+++ b/tests/local_testing/test_amazing_vertex_completion.py
@@ -930,7 +930,7 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
         "vertex_ai/mistral-large@2407",
         "vertex_ai/mistral-nemo@2407",
         "vertex_ai/codestral@2405",
-        "vertex_ai/meta/llama3-405b-instruct-maas",
+        # "vertex_ai/meta/llama3-405b-instruct-maas",
     ],  #
 )  # "vertex_ai",
 @pytest.mark.parametrize(
@@ -960,7 +960,6 @@ async def test_partner_models_httpx(model, sync_mode):
             "model": model,
             "messages": messages,
             "timeout": 10,
-            "mock_response": "Hello, how are you?",
         }
         if sync_mode:
             response = litellm.completion(**data)
@@ -993,7 +992,8 @@ async def test_partner_models_httpx(model, sync_mode):
     "model",
     [
         "vertex_ai/mistral-large@2407",
-        "vertex_ai/meta/llama3-405b-instruct-maas",
+        # "vertex_ai/meta/llama3-405b-instruct-maas",
+        "vertex_ai/codestral@2405",
     ],  #
 )  # "vertex_ai",
 @pytest.mark.parametrize(
@@ -1023,7 +1023,6 @@ async def test_partner_models_httpx_streaming(model, sync_mode):
             "model": model,
             "messages": messages,
             "stream": True,
-            "mock_response": "Hello, how are you?",
         }
         if sync_mode:
             response = litellm.completion(**data)
@@ -3193,3 +3192,16 @@ async def test_vertexai_model_garden_model_completion(
     assert response.usage.completion_tokens == 109
     assert response.usage.prompt_tokens == 63
     assert response.usage.total_tokens == 172
+
+
+def test_vertexai_code_gecko():
+    litellm.set_verbose = True
+    load_vertex_ai_credentials()
+    response = completion(
+        model="vertex_ai/code-gecko@002",
+        messages=[{"role": "user", "content": "Hello world!"}],
+        stream=True,
+    )
+
+    for chunk in response:
+        print(chunk)
diff --git a/tests/local_testing/test_get_model_info.py b/tests/local_testing/test_get_model_info.py
index 29c02c7cc5ad..2c313cecad1f 100644
--- a/tests/local_testing/test_get_model_info.py
+++ b/tests/local_testing/test_get_model_info.py
@@ -247,3 +247,41 @@ def test_model_info_bedrock_converse_enforcement(monkeypatch):
         )
     except FileNotFoundError as e:
         pytest.skip("whitelisted_bedrock_models.txt not found")
+
+
+def test_get_model_info_custom_provider():
+    # Custom provider example copied from https://docs.litellm.ai/docs/providers/custom_llm_server:
+    import litellm
+    from litellm import CustomLLM, completion, get_llm_provider
+
+    class MyCustomLLM(CustomLLM):
+        def completion(self, *args, **kwargs) -> litellm.ModelResponse:
+            return litellm.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hello world"}],
+                mock_response="Hi!",
+            )  # type: ignore
+
+    my_custom_llm = MyCustomLLM()
+
+    litellm.custom_provider_map = [  # 👈 KEY STEP - REGISTER HANDLER
+        {"provider": "my-custom-llm", "custom_handler": my_custom_llm}
+    ]
+
+    resp = completion(
+        model="my-custom-llm/my-fake-model",
+        messages=[{"role": "user", "content": "Hello world!"}],
+    )
+
+    assert resp.choices[0].message.content == "Hi!"
+
+    # Register model info
+    model_info = {"my-custom-llm/my-fake-model": {"max_tokens": 2048}}
+    litellm.register_model(model_info)
+
+    # Get registered model info
+    from litellm import get_model_info
+
+    get_model_info(
+        model="my-custom-llm/my-fake-model"
+    )  # 💥 "Exception: This model isn't mapped yet." in v1.56.10