From db3c500fbfed01ba7fdf7e16c2f509f5851238c9 Mon Sep 17 00:00:00 2001 From: Yunmo Koo Date: Wed, 14 Aug 2024 11:32:56 +0900 Subject: [PATCH] Update patch version v1.5.1 --- friendli/cli/api/chat_completions.py | 11 +--- friendli/cli/api/completions.py | 11 +--- friendli/sdk/api/base.py | 84 ++++++------------------ friendli/sdk/api/chat/chat.py | 4 -- friendli/sdk/api/chat/completions.py | 6 +- friendli/sdk/api/completions.py | 6 +- friendli/sdk/api/images/images.py | 8 +-- friendli/sdk/api/images/text_to_image.py | 6 +- friendli/sdk/client.py | 26 ++++---- friendli/utils/decorator.py | 20 ------ pyproject.toml | 2 +- 11 files changed, 52 insertions(+), 132 deletions(-) diff --git a/friendli/cli/api/chat_completions.py b/friendli/cli/api/chat_completions.py index 8503e336..318866ef 100644 --- a/friendli/cli/api/chat_completions.py +++ b/friendli/cli/api/chat_completions.py @@ -13,7 +13,7 @@ from friendli.schema.api.v1.chat.completions import MessageParam from friendli.sdk.client import Friendli from friendli.utils.compat import model_dump -from friendli.utils.decorator import check_api, check_api_params +from friendli.utils.decorator import check_api from friendli.utils.format import secho_error_and_exit app = typer.Typer( @@ -24,7 +24,6 @@ @app.command() -@check_api_params @check_api def create( messages: List[str] = typer.Option( @@ -46,12 +45,6 @@ def create( "about available models and pricing." ), ), - endpoint_id: Optional[str] = typer.Option( - None, - "--endpoint-id", - "-e", - help="Dedicated endpoint ID to send request.", - ), n: Optional[int] = typer.Option( None, "--n", @@ -128,7 +121,7 @@ def create( ), ): """Creates chat completions.""" - client = Friendli(token=token, team_id=team_id, endpoint_id=endpoint_id) + client = Friendli(token=token, team_id=team_id) if enable_stream: stream = client.chat.completions.create( stream=True, diff --git a/friendli/cli/api/completions.py b/friendli/cli/api/completions.py index ba35751a..2981c28e 100644 --- a/friendli/cli/api/completions.py +++ b/friendli/cli/api/completions.py @@ -12,7 +12,7 @@ from friendli.sdk.client import Friendli from friendli.utils.compat import model_dump -from friendli.utils.decorator import check_api, check_api_params +from friendli.utils.decorator import check_api app = typer.Typer( no_args_is_help=True, @@ -22,7 +22,6 @@ @app.command() -@check_api_params @check_api def create( prompt: str = typer.Option( @@ -41,12 +40,6 @@ def create( "about available models and pricing." 
), ), - endpoint_id: Optional[str] = typer.Option( - None, - "--endpoint-id", - "-e", - help="Dedicated endpoint ID to send request.", - ), n: Optional[int] = typer.Option( None, "--n", @@ -121,7 +114,7 @@ def create( team_id: Optional[str] = typer.Option(None, "--team", help="ID of team to run as."), ): """Creates text completions.""" - client = Friendli(token=token, team_id=team_id, endpoint_id=endpoint_id) + client = Friendli(token=token, team_id=team_id) if enable_stream: stream = client.completions.create( stream=True, diff --git a/friendli/sdk/api/base.py b/friendli/sdk/api/base.py index 8c803c50..8245dead 100644 --- a/friendli/sdk/api/base.py +++ b/friendli/sdk/api/base.py @@ -122,11 +122,9 @@ class BaseAPI(ABC, Generic[_HttpxClient, _ProtoMsgType]): def __init__( self, base_url: Optional[str] = None, - endpoint_id: Optional[str] = None, use_protobuf: bool = False, ) -> None: """Initializes BaseAPI.""" - self._endpoint_id = endpoint_id self._base_url = base_url self._use_protobuf = use_protobuf @@ -150,26 +148,20 @@ def _content_type(self) -> str: def _request_pb_cls(self) -> _ProtoMsgType: """Protobuf message class to serialize the data of request body.""" - def _build_http_request( - self, data: dict[str, Any], model: Optional[str] = None - ) -> httpx.Request: + def _build_http_request(self, data: dict[str, Any]) -> httpx.Request: """Build request.""" return self._http_client.build_request( method=self._method, url=self._build_http_url(), - content=self._build_content(data, model), + content=self._build_content(data), files=self._build_files(data), headers=self._get_headers(), ) def _build_http_url(self) -> httpx.URL: assert self._base_url is not None - path = "" - if self._endpoint_id is not None: - path = "dedicated" - path = os.path.join(path, self._api_path) - host = httpx.URL(self._base_url) - return host.join(path) + url = os.path.join(self._base_url, self._api_path) + return httpx.URL(url) def _build_grpc_url(self) -> str: if self._base_url is None: @@ -193,14 +185,7 @@ def _build_files(self, data: dict[str, Any]) -> dict[str, Any] | None: return files return None - def _build_content( - self, data: dict[str, Any], model: Optional[str] = None - ) -> bytes | None: - if self._endpoint_id is not None: - data["model"] = self._endpoint_id - else: - data["model"] = model - + def _build_content(self, data: dict[str, Any]) -> bytes | None: if self._content_type.startswith("multipart/form-data"): return None @@ -212,14 +197,7 @@ def _build_content( return json.dumps(data).encode() - def _build_grpc_request( - self, data: dict[str, Any], model: Optional[str] = None - ) -> pb_message.Message: - if self._endpoint_id is not None: - data["model"] = self._endpoint_id - else: - data["model"] = model - + def _build_grpc_request(self, data: dict[str, Any]) -> pb_message.Message: pb_cls = self._request_pb_cls return pb_cls(**data) @@ -230,7 +208,6 @@ class ServingAPI(BaseAPI[httpx.Client, _ProtoMsgType]): def __init__( self, base_url: Optional[str] = None, - endpoint_id: Optional[str] = None, use_protobuf: bool = False, use_grpc: bool = False, http_client: Optional[httpx.Client] = None, @@ -239,7 +216,6 @@ def __init__( """Initializes ServingAPI.""" super().__init__( base_url=base_url, - endpoint_id=endpoint_id, use_protobuf=use_protobuf, ) @@ -265,23 +241,12 @@ def close(self) -> None: def _get_grpc_stub(self, channel: grpc.Channel) -> Any: raise NotImplementedError # pragma: no cover - def _request( - self, *, data: dict[str, Any], stream: bool, model: Optional[str] = None - ) -> Any: + def 
_request(self, *, data: dict[str, Any], stream: bool) -> Any: # TODO: Add retry / handle timeout and etc. - if ( - self._base_url == "https://inference.friendli.ai" - and self._endpoint_id is None - and model is None - ): - raise ValueError("`model` is required for serverless endpoints.") - if self._endpoint_id is not None and model is not None: - raise ValueError("`model` is not allowed for dedicated endpoints.") - data = transform_request_data(data) if self._use_grpc: - grpc_request = self._build_grpc_request(data=data, model=model) + grpc_request = self._build_grpc_request(data=data) if not self._grpc_channel: self._grpc_channel = grpc.insecure_channel(self._build_grpc_url()) try: @@ -293,7 +258,7 @@ def _request( grpc_response = self._grpc_stub.Generate(grpc_request) return grpc_response - http_request = self._build_http_request(data=data, model=model) + http_request = self._build_http_request(data=data) http_response = self._http_client.send(request=http_request, stream=stream) self._check_http_error(http_response) return http_response @@ -303,10 +268,13 @@ def _check_http_error(self, response: httpx.Response) -> None: response.raise_for_status() except httpx.HTTPStatusError as exc: if response.status_code == 404: + endpoint_url = self._build_http_url() raise APIError( - "Endpoint is not found. This may be due to an invalid model name. " - "See https://docs.friendli.ai/guides/serverless_endpoints/pricing " - "to find out availble models." + f"Endpoint ({endpoint_url}) was not found. This may be due to an " + "invalid model name or endpoint ID. For serverless endpoints, see " + "https://docs.friendli.ai/guides/serverless_endpoints/pricing " + "to find out available models. For dedicated endpoints, check your " + "endpoint ID again." ) from exc resp_content = response.read() @@ -319,16 +287,13 @@ class AsyncServingAPI(BaseAPI[httpx.AsyncClient, _ProtoMsgType]): def __init__( self, base_url: Optional[str] = None, - endpoint_id: Optional[str] = None, use_protobuf: bool = False, use_grpc: bool = False, http_client: Optional[httpx.AsyncClient] = None, grpc_channel: Optional[grpc.aio.Channel] = None, ) -> None: """Initializes AsyncServingAPI.""" - super().__init__( - base_url=base_url, endpoint_id=endpoint_id, use_protobuf=use_protobuf - ) + super().__init__(base_url=base_url, use_protobuf=use_protobuf) self._use_grpc = use_grpc self._http_client = http_client or _DefaultAsyncHttpxClient() @@ -352,23 +317,12 @@ async def close(self) -> None: def _get_grpc_stub(self, channel: grpc.aio.Channel) -> Any: raise NotImplementedError # pragma: no cover - async def _request( - self, *, data: dict[str, Any], stream: bool, model: Optional[str] = None - ) -> Any: + async def _request(self, *, data: dict[str, Any], stream: bool) -> Any: # TODO: Add retry / handle timeout and etc. 
- if ( - self._base_url == "https://inference.friendli.ai" - and self._endpoint_id is None - and model is None - ): - raise ValueError("`model` is required for serverless endpoints.") - if self._endpoint_id is not None and model is not None: - raise ValueError("`model` is not allowed for dedicated endpoints.") - data = transform_request_data(data) if self._use_grpc: - grpc_request = self._build_grpc_request(data=data, model=model) + grpc_request = self._build_grpc_request(data=data) if not self._grpc_channel: self._grpc_channel = grpc.aio.insecure_channel(self._build_grpc_url()) try: @@ -382,7 +336,7 @@ async def _request( ) return grpc_response - http_request = self._build_http_request(data=data, model=model) + http_request = self._build_http_request(data=data) http_response = await self._http_client.send( request=http_request, stream=stream ) diff --git a/friendli/sdk/api/chat/chat.py b/friendli/sdk/api/chat/chat.py index 7351ac01..b1cd0612 100644 --- a/friendli/sdk/api/chat/chat.py +++ b/friendli/sdk/api/chat/chat.py @@ -21,7 +21,6 @@ class Chat: def __init__( self, base_url: Optional[str] = None, - endpoint_id: Optional[str] = None, use_protobuf: bool = False, use_grpc: bool = False, http_client: Optional[httpx.Client] = None, @@ -30,7 +29,6 @@ def __init__( """Initializes Chat.""" self.completions = Completions( base_url=base_url, - endpoint_id=endpoint_id, use_protobuf=use_protobuf, use_grpc=use_grpc, http_client=http_client, @@ -50,7 +48,6 @@ class AsyncChat: def __init__( self, base_url: Optional[str] = None, - endpoint_id: Optional[str] = None, use_protobuf: bool = False, use_grpc: bool = False, http_client: Optional[httpx.AsyncClient] = None, @@ -59,7 +56,6 @@ def __init__( """Initializes AsyncChat.""" self.completions = AsyncCompletions( base_url=base_url, - endpoint_id=endpoint_id, use_protobuf=use_protobuf, use_grpc=use_grpc, http_client=http_client, diff --git a/friendli/sdk/api/chat/completions.py b/friendli/sdk/api/chat/completions.py index e2a103a5..5fab769b 100644 --- a/friendli/sdk/api/chat/completions.py +++ b/friendli/sdk/api/chat/completions.py @@ -148,6 +148,7 @@ def create( request_dict = { "messages": messages, "stream": stream, + "model": model, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "repetition_penalty": repetition_penalty, @@ -166,7 +167,7 @@ def create( "tool_choice": tool_choice, "response_format": response_format, } - response = self._request(data=request_dict, stream=stream, model=model) + response = self._request(data=request_dict, stream=stream) if stream: return ChatCompletionStream(response=response) @@ -292,6 +293,7 @@ async def create( request_dict = { "messages": messages, "stream": stream, + "model": model, "frequency_penalty": frequency_penalty, "presence_penalty": presence_penalty, "repetition_penalty": repetition_penalty, @@ -310,7 +312,7 @@ async def create( "tool_choice": tool_choice, "response_format": response_format, } - response = await self._request(data=request_dict, stream=stream, model=model) + response = await self._request(data=request_dict, stream=stream) if stream: return AsyncChatCompletionStream(response=response) diff --git a/friendli/sdk/api/completions.py b/friendli/sdk/api/completions.py index bd872162..bf32436b 100644 --- a/friendli/sdk/api/completions.py +++ b/friendli/sdk/api/completions.py @@ -296,6 +296,7 @@ def create( request_dict = { "stream": stream, + "model": model, "prompt": prompt, "tokens": tokens, "timeout_microseconds": timeout_microseconds, @@ -331,7 +332,7 @@ def create( 
"forced_output_tokens": forced_output_tokens, "eos_token": eos_token, } - response = self._request(data=request_dict, stream=stream, model=model) + response = self._request(data=request_dict, stream=stream) if stream: if self._use_grpc: @@ -589,6 +590,7 @@ async def main() -> None: request_dict = { "stream": stream, + "model": model, "prompt": prompt, "tokens": tokens, "timeout_microseconds": timeout_microseconds, @@ -624,7 +626,7 @@ async def main() -> None: "forced_output_tokens": forced_output_tokens, "eos_token": eos_token, } - response = await self._request(data=request_dict, stream=stream, model=model) + response = await self._request(data=request_dict, stream=stream) if stream: if self._use_grpc: diff --git a/friendli/sdk/api/images/images.py b/friendli/sdk/api/images/images.py index fb58e658..e100baa4 100644 --- a/friendli/sdk/api/images/images.py +++ b/friendli/sdk/api/images/images.py @@ -19,13 +19,10 @@ class Images: def __init__( self, base_url: Optional[str] = None, - endpoint_id: Optional[str] = None, http_client: Optional[httpx.Client] = None, ) -> None: """Initialize Images.""" - self.text_to_image = TextToImage( - base_url=base_url, endpoint_id=endpoint_id, http_client=http_client - ) + self.text_to_image = TextToImage(base_url=base_url, http_client=http_client) def close(self) -> None: """Clean up all clients' resources.""" @@ -40,12 +37,11 @@ class AsyncImages: def __init__( self, base_url: Optional[str] = None, - endpoint_id: Optional[str] = None, http_client: Optional[httpx.AsyncClient] = None, ) -> None: """Initialize Images.""" self.text_to_image = AsyncTextToImage( - base_url=base_url, endpoint_id=endpoint_id, http_client=http_client + base_url=base_url, http_client=http_client ) async def close(self) -> None: diff --git a/friendli/sdk/api/images/text_to_image.py b/friendli/sdk/api/images/text_to_image.py index 16913be4..f0e432e6 100644 --- a/friendli/sdk/api/images/text_to_image.py +++ b/friendli/sdk/api/images/text_to_image.py @@ -65,6 +65,7 @@ def create( """ request_dict = { "prompt": prompt, + "model": model, "negative_prompt": negative_prompt, "num_outputs": num_outputs, "num_inference_steps": num_inference_steps, @@ -72,7 +73,7 @@ def create( "seed": seed, "response_format": response_format, } - response = self._request(data=request_dict, stream=False, model=model) + response = self._request(data=request_dict, stream=False) return model_parse(Image, response.json()) @@ -127,6 +128,7 @@ async def create( """ request_dict = { "prompt": prompt, + "model": model, "negative_prompt": negative_prompt, "num_outputs": num_outputs, "num_inference_steps": num_inference_steps, @@ -134,6 +136,6 @@ async def create( "seed": seed, "response_format": response_format, } - response = await self._request(data=request_dict, stream=False, model=model) + response = await self._request(data=request_dict, stream=False) return model_parse(Image, response.json()) diff --git a/friendli/sdk/client.py b/friendli/sdk/client.py index 6930641c..146342fd 100644 --- a/friendli/sdk/client.py +++ b/friendli/sdk/client.py @@ -4,6 +4,7 @@ from __future__ import annotations +import os from typing import Optional, Union import grpc @@ -31,7 +32,7 @@ def __init__( token: Optional[str] = None, team_id: Optional[str] = None, project_id: Optional[str] = None, - endpoint_id: Optional[str] = None, + use_dedicated_endpoint: bool = False, base_url: Optional[str] = None, use_protobuf: bool = False, use_grpc: bool = False, @@ -45,7 +46,7 @@ def __init__( friendli.team_id = team_id if project_id is not None: 
friendli.project_id = project_id - self._endpoint_id = endpoint_id + self._use_dedicated_endpoint = use_dedicated_endpoint self._base_url = base_url self._use_protobuf = use_protobuf @@ -56,6 +57,10 @@ def __init__( ) if http_client is not None: raise ValueError("You cannot use HTTP client when `use_grpc=True`.") + if use_dedicated_endpoint: + raise ValueError( + "`use_grpc=True` is not allowed for dedicated endpoints." + ) else: if grpc_channel is not None: raise ValueError( @@ -64,6 +69,9 @@ def __init__( if base_url is None: self._base_url = INFERENCE_ENDPOINT_URL + if use_dedicated_endpoint: + self._base_url = os.path.join(self._base_url, "dedicated") + class Friendli(FriendliClientBase): """Friendli API client.""" @@ -80,7 +88,7 @@ def __init__( token: Optional[str] = None, team_id: Optional[str] = None, project_id: Optional[str] = None, - endpoint_id: Optional[str] = None, + use_dedicated_endpoint: bool = False, base_url: Optional[str] = None, use_protobuf: bool = False, use_grpc: bool = False, @@ -92,7 +100,7 @@ def __init__( token=token, team_id=team_id, project_id=project_id, - endpoint_id=endpoint_id, + use_dedicated_endpoint=use_dedicated_endpoint, base_url=base_url, use_protobuf=use_protobuf, use_grpc=use_grpc, @@ -102,7 +110,6 @@ def __init__( self.completions = Completions( base_url=self._base_url, - endpoint_id=self._endpoint_id, use_protobuf=use_protobuf, use_grpc=use_grpc, http_client=http_client, @@ -110,7 +117,6 @@ def __init__( ) self.chat = Chat( base_url=self._base_url, - endpoint_id=self._endpoint_id, use_protobuf=use_protobuf, use_grpc=use_grpc, http_client=http_client, @@ -118,7 +124,6 @@ def __init__( ) self.images = Images( base_url=self._base_url, - endpoint_id=self._endpoint_id, http_client=http_client, ) @@ -155,7 +160,7 @@ def __init__( token: Optional[str] = None, team_id: Optional[str] = None, project_id: Optional[str] = None, - endpoint_id: Optional[str] = None, + use_dedicated_endpoint: bool = False, base_url: Optional[str] = None, use_protobuf: bool = False, use_grpc: bool = False, @@ -167,7 +172,7 @@ def __init__( token=token, team_id=team_id, project_id=project_id, - endpoint_id=endpoint_id, + use_dedicated_endpoint=use_dedicated_endpoint, base_url=base_url, use_protobuf=use_protobuf, use_grpc=use_grpc, @@ -177,7 +182,6 @@ def __init__( self.completions = AsyncCompletions( base_url=self._base_url, - endpoint_id=self._endpoint_id, use_protobuf=use_protobuf, use_grpc=use_grpc, http_client=http_client, @@ -185,7 +189,6 @@ def __init__( ) self.chat = AsyncChat( base_url=self._base_url, - endpoint_id=self._endpoint_id, use_protobuf=use_protobuf, use_grpc=use_grpc, http_client=http_client, @@ -193,7 +196,6 @@ def __init__( ) self.images = AsyncImages( base_url=self._base_url, - endpoint_id=self._endpoint_id, http_client=http_client, ) diff --git a/friendli/utils/decorator.py b/friendli/utils/decorator.py index a1d07221..8075fc89 100644 --- a/friendli/utils/decorator.py +++ b/friendli/utils/decorator.py @@ -22,23 +22,3 @@ def inner(*args, **kwargs) -> Any: secho_error_and_exit(str(exc)) return inner - - -def check_api_params(func: Callable[..., Any]) -> Callable[..., Any]: - """Check API params.""" - - @functools.wraps(func) - def inner(*args, **kwargs) -> Any: - model = kwargs["model"] - endpoint_id = kwargs["endpoint_id"] - - if model is None and endpoint_id is None: - secho_error_and_exit("One of 'model' and 'endpoint_id' should be provided.") - if model is not None and endpoint_id is not None: - secho_error_and_exit( - "Only one of 'model' and 
'endpoint_id' should be provided." - ) - - return func(*args, **kwargs) - - return inner diff --git a/pyproject.toml b/pyproject.toml index a6246f77..6197dc0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "friendli-client" -version = "1.5.0" +version = "1.5.1" description = "Client of Friendli Suite." license = "Apache-2.0" authors = ["FriendliAI teams "]
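Note: beyond the version bump, this patch removes the `endpoint_id` argument in favor of a `use_dedicated_endpoint` flag on the client. When the flag is set, requests are routed under the "dedicated" path of the base URL, the dedicated endpoint ID is passed as the `model` field of each request, and `use_grpc=True` is rejected. The sketch below illustrates the new call pattern under those assumptions; the token, model name, and endpoint ID values are placeholders, and the response access assumes the SDK's OpenAI-compatible chat completion schema.

from friendli import Friendli

# Serverless endpoints (default): pass a public model name as `model`.
client = Friendli(token="YOUR_FRIENDLI_TOKEN")
completion = client.chat.completions.create(
    model="meta-llama-3.1-8b-instruct",  # placeholder model name
    messages=[{"role": "user", "content": "Hello!"}],
)
print(completion.choices[0].message.content)

# Dedicated endpoints: set the flag and pass the endpoint ID as `model`.
dedicated = Friendli(
    token="YOUR_FRIENDLI_TOKEN",
    use_dedicated_endpoint=True,
)
completion = dedicated.chat.completions.create(
    model="YOUR_ENDPOINT_ID",  # placeholder dedicated endpoint ID
    messages=[{"role": "user", "content": "Hello!"}],
)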