Update patch version v1.5.1
kooyunmo committed Aug 14, 2024 (commit db3c500, 1 parent: 0d8f5f9)
Showing 11 changed files with 52 additions and 132 deletions.
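Taken together, the diffs below remove the dedicated-endpoint plumbing from the client: the `endpoint_id` constructor parameter and the CLI's `--endpoint-id`/`-e` option are gone, and the target model now travels inside the request payload as a regular `model` field set by `create()`. A minimal usage sketch after this change; the token and model name are placeholders, and the message shape is assumed to follow OpenAI-style role/content dicts, which this diff does not show:

from friendli.sdk.client import Friendli

# Placeholder credentials and model name, for illustration only.
client = Friendli(token="YOUR_FRIENDLI_TOKEN")

# `model` is now an ordinary request field handled by create(); there is
# no endpoint_id anywhere in the call.
completion = client.chat.completions.create(
    model="meta-llama-3.1-8b-instruct",
    messages=[{"role": "user", "content": "Hello!"}],
    stream=False,
)
print(completion)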
friendli/cli/api/chat_completions.py (11 changes: 2 additions & 9 deletions)

@@ -13,7 +13,7 @@
 from friendli.schema.api.v1.chat.completions import MessageParam
 from friendli.sdk.client import Friendli
 from friendli.utils.compat import model_dump
-from friendli.utils.decorator import check_api, check_api_params
+from friendli.utils.decorator import check_api
 from friendli.utils.format import secho_error_and_exit

 app = typer.Typer(
@@ -24,7 +24,6 @@


 @app.command()
-@check_api_params
 @check_api
 def create(
     messages: List[str] = typer.Option(
@@ -46,12 +45,6 @@ def create(
             "about available models and pricing."
         ),
     ),
-    endpoint_id: Optional[str] = typer.Option(
-        None,
-        "--endpoint-id",
-        "-e",
-        help="Dedicated endpoint ID to send request.",
-    ),
     n: Optional[int] = typer.Option(
         None,
         "--n",
@@ -128,7 +121,7 @@ def create(
     ),
 ):
     """Creates chat completions."""
-    client = Friendli(token=token, team_id=team_id, endpoint_id=endpoint_id)
+    client = Friendli(token=token, team_id=team_id)
     if enable_stream:
         stream = client.chat.completions.create(
             stream=True,
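The `enable_stream` branch above is, in SDK terms, roughly the following sketch; that the returned `ChatCompletionStream` can be iterated chunk by chunk is assumed from its name, since chunk internals are not shown in this diff:

from friendli.sdk.client import Friendli

client = Friendli(token="YOUR_FRIENDLI_TOKEN")  # placeholder token

stream = client.chat.completions.create(
    stream=True,  # matches the CLI's enable_stream path
    model="meta-llama-3.1-8b-instruct",  # placeholder model name
    messages=[{"role": "user", "content": "Tell me a joke."}],
)
for chunk in stream:  # assumed: the stream object yields chunks on iteration
    print(chunk, flush=True)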
friendli/cli/api/completions.py (11 changes: 2 additions & 9 deletions)

@@ -12,7 +12,7 @@

 from friendli.sdk.client import Friendli
 from friendli.utils.compat import model_dump
-from friendli.utils.decorator import check_api, check_api_params
+from friendli.utils.decorator import check_api

 app = typer.Typer(
     no_args_is_help=True,
@@ -22,7 +22,6 @@


 @app.command()
-@check_api_params
 @check_api
 def create(
     prompt: str = typer.Option(
@@ -41,12 +40,6 @@ def create(
             "about available models and pricing."
         ),
     ),
-    endpoint_id: Optional[str] = typer.Option(
-        None,
-        "--endpoint-id",
-        "-e",
-        help="Dedicated endpoint ID to send request.",
-    ),
     n: Optional[int] = typer.Option(
         None,
         "--n",
@@ -121,7 +114,7 @@ def create(
     team_id: Optional[str] = typer.Option(None, "--team", help="ID of team to run as."),
 ):
     """Creates text completions."""
-    client = Friendli(token=token, team_id=team_id, endpoint_id=endpoint_id)
+    client = Friendli(token=token, team_id=team_id)
     if enable_stream:
         stream = client.completions.create(
             stream=True,
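The text-completions CLI mirrors the chat change one-to-one. As a sketch, the equivalent SDK call after this commit looks like this; the token and model name are placeholders, and the `prompt` keyword is taken from the request fields visible in friendli/sdk/api/completions.py below:

from friendli.sdk.client import Friendli

client = Friendli(token="YOUR_FRIENDLI_TOKEN")  # placeholder token

completion = client.completions.create(
    model="meta-llama-3.1-8b-instruct",  # placeholder model name
    prompt="Say this is a test.",
    stream=False,
)
print(completion)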
friendli/sdk/api/base.py (84 changes: 19 additions & 65 deletions)

@@ -122,11 +122,9 @@ class BaseAPI(ABC, Generic[_HttpxClient, _ProtoMsgType]):
     def __init__(
         self,
         base_url: Optional[str] = None,
-        endpoint_id: Optional[str] = None,
         use_protobuf: bool = False,
     ) -> None:
         """Initializes BaseAPI."""
-        self._endpoint_id = endpoint_id
         self._base_url = base_url
         self._use_protobuf = use_protobuf

@@ -150,26 +148,20 @@ def _content_type(self) -> str:
     def _request_pb_cls(self) -> _ProtoMsgType:
         """Protobuf message class to serialize the data of request body."""

-    def _build_http_request(
-        self, data: dict[str, Any], model: Optional[str] = None
-    ) -> httpx.Request:
+    def _build_http_request(self, data: dict[str, Any]) -> httpx.Request:
         """Build request."""
         return self._http_client.build_request(
             method=self._method,
             url=self._build_http_url(),
-            content=self._build_content(data, model),
+            content=self._build_content(data),
             files=self._build_files(data),
             headers=self._get_headers(),
         )

     def _build_http_url(self) -> httpx.URL:
         assert self._base_url is not None
-        path = ""
-        if self._endpoint_id is not None:
-            path = "dedicated"
-        path = os.path.join(path, self._api_path)
-        host = httpx.URL(self._base_url)
-        return host.join(path)
+        url = os.path.join(self._base_url, self._api_path)
+        return httpx.URL(url)

     def _build_grpc_url(self) -> str:
         if self._base_url is None:
@@ -193,14 +185,7 @@ def _build_files(self, data: dict[str, Any]) -> dict[str, Any] | None:
             return files
         return None

-    def _build_content(
-        self, data: dict[str, Any], model: Optional[str] = None
-    ) -> bytes | None:
-        if self._endpoint_id is not None:
-            data["model"] = self._endpoint_id
-        else:
-            data["model"] = model
-
+    def _build_content(self, data: dict[str, Any]) -> bytes | None:
         if self._content_type.startswith("multipart/form-data"):
             return None

@@ -212,14 +197,7 @@ def _build_content(

         return json.dumps(data).encode()

-    def _build_grpc_request(
-        self, data: dict[str, Any], model: Optional[str] = None
-    ) -> pb_message.Message:
-        if self._endpoint_id is not None:
-            data["model"] = self._endpoint_id
-        else:
-            data["model"] = model
-
+    def _build_grpc_request(self, data: dict[str, Any]) -> pb_message.Message:
         pb_cls = self._request_pb_cls
         return pb_cls(**data)

@@ -230,7 +208,6 @@ class ServingAPI(BaseAPI[httpx.Client, _ProtoMsgType]):
     def __init__(
         self,
         base_url: Optional[str] = None,
-        endpoint_id: Optional[str] = None,
         use_protobuf: bool = False,
         use_grpc: bool = False,
         http_client: Optional[httpx.Client] = None,
@@ -239,7 +216,6 @@ def __init__(
         """Initializes ServingAPI."""
         super().__init__(
             base_url=base_url,
-            endpoint_id=endpoint_id,
             use_protobuf=use_protobuf,
         )

@@ -265,23 +241,12 @@ def close(self) -> None:
     def _get_grpc_stub(self, channel: grpc.Channel) -> Any:
         raise NotImplementedError  # pragma: no cover

-    def _request(
-        self, *, data: dict[str, Any], stream: bool, model: Optional[str] = None
-    ) -> Any:
+    def _request(self, *, data: dict[str, Any], stream: bool) -> Any:
         # TODO: Add retry / handle timeout and etc.
-        if (
-            self._base_url == "https://inference.friendli.ai"
-            and self._endpoint_id is None
-            and model is None
-        ):
-            raise ValueError("`model` is required for serverless endpoints.")
-        if self._endpoint_id is not None and model is not None:
-            raise ValueError("`model` is not allowed for dedicated endpoints.")
-
         data = transform_request_data(data)

         if self._use_grpc:
-            grpc_request = self._build_grpc_request(data=data, model=model)
+            grpc_request = self._build_grpc_request(data=data)
             if not self._grpc_channel:
                 self._grpc_channel = grpc.insecure_channel(self._build_grpc_url())
             try:
@@ -293,7 +258,7 @@ def _request(
             grpc_response = self._grpc_stub.Generate(grpc_request)
             return grpc_response

-        http_request = self._build_http_request(data=data, model=model)
+        http_request = self._build_http_request(data=data)
         http_response = self._http_client.send(request=http_request, stream=stream)
         self._check_http_error(http_response)
         return http_response
@@ -303,10 +268,13 @@ def _check_http_error(self, response: httpx.Response) -> None:
             response.raise_for_status()
         except httpx.HTTPStatusError as exc:
             if response.status_code == 404:
+                endpoint_url = self._build_http_url()
                 raise APIError(
-                    "Endpoint is not found. This may be due to an invalid model name. "
-                    "See https://docs.friendli.ai/guides/serverless_endpoints/pricing "
-                    "to find out availble models."
+                    f"Endpoint ({endpoint_url}) is not found. This may be due to an "
+                    "invalid model name or endpoint ID. For serverless endpoints, see "
+                    "https://docs.friendli.ai/guides/serverless_endpoints/pricing "
+                    "to find out availble models. For dedicated endpoints, check your "
+                    "endpoiont ID again."
                 ) from exc

         resp_content = response.read()
@@ -319,16 +287,13 @@ class AsyncServingAPI(BaseAPI[httpx.AsyncClient, _ProtoMsgType]):
     def __init__(
         self,
         base_url: Optional[str] = None,
-        endpoint_id: Optional[str] = None,
         use_protobuf: bool = False,
         use_grpc: bool = False,
         http_client: Optional[httpx.AsyncClient] = None,
         grpc_channel: Optional[grpc.aio.Channel] = None,
     ) -> None:
         """Initializes AsyncServingAPI."""
-        super().__init__(
-            base_url=base_url, endpoint_id=endpoint_id, use_protobuf=use_protobuf
-        )
+        super().__init__(base_url=base_url, use_protobuf=use_protobuf)

         self._use_grpc = use_grpc
         self._http_client = http_client or _DefaultAsyncHttpxClient()
@@ -352,23 +317,12 @@ async def close(self) -> None:
     def _get_grpc_stub(self, channel: grpc.aio.Channel) -> Any:
         raise NotImplementedError  # pragma: no cover

-    async def _request(
-        self, *, data: dict[str, Any], stream: bool, model: Optional[str] = None
-    ) -> Any:
+    async def _request(self, *, data: dict[str, Any], stream: bool) -> Any:
        # TODO: Add retry / handle timeout and etc.
-        if (
-            self._base_url == "https://inference.friendli.ai"
-            and self._endpoint_id is None
-            and model is None
-        ):
-            raise ValueError("`model` is required for serverless endpoints.")
-        if self._endpoint_id is not None and model is not None:
-            raise ValueError("`model` is not allowed for dedicated endpoints.")
-
         data = transform_request_data(data)

         if self._use_grpc:
-            grpc_request = self._build_grpc_request(data=data, model=model)
+            grpc_request = self._build_grpc_request(data=data)
             if not self._grpc_channel:
                 self._grpc_channel = grpc.aio.insecure_channel(self._build_grpc_url())
             try:
@@ -382,7 +336,7 @@ async def _request(
             )
             return grpc_response

-        http_request = self._build_http_request(data=data, model=model)
+        http_request = self._build_http_request(data=data)
         http_response = await self._http_client.send(
             request=http_request, stream=stream
         )
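The URL construction in `_build_http_url` is the most subtle change here: the old code routed dedicated endpoints under a `dedicated` path prefix via `httpx.URL.join`, while the new code simply path-joins the base URL and the API path. A standalone sketch of the new behavior; the `api_path` value is an assumption, since `_api_path` is defined outside this diff:

import os

import httpx

base_url = "https://inference.friendli.ai"
api_path = "v1/chat/completions"  # assumed value of self._api_path

# Mirrors the new _build_http_url(): os.path.join, then wrap in httpx.URL.
url = httpx.URL(os.path.join(base_url, api_path))
print(url)  # https://inference.friendli.ai/v1/chat/completions

One quirk worth knowing: `os.path.join` uses filesystem semantics, so an `api_path` with a leading slash would discard the base URL entirely. URL-specific joiners such as `urllib.parse.urljoin` are the more common choice; the sketch just reproduces what the diff does.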
friendli/sdk/api/chat/chat.py (4 changes: 0 additions & 4 deletions)

@@ -21,7 +21,6 @@ class Chat:
     def __init__(
         self,
         base_url: Optional[str] = None,
-        endpoint_id: Optional[str] = None,
         use_protobuf: bool = False,
         use_grpc: bool = False,
         http_client: Optional[httpx.Client] = None,
@@ -30,7 +29,6 @@ def __init__(
         """Initializes Chat."""
         self.completions = Completions(
             base_url=base_url,
-            endpoint_id=endpoint_id,
             use_protobuf=use_protobuf,
             use_grpc=use_grpc,
             http_client=http_client,
@@ -50,7 +48,6 @@ class AsyncChat:
     def __init__(
         self,
         base_url: Optional[str] = None,
-        endpoint_id: Optional[str] = None,
         use_protobuf: bool = False,
         use_grpc: bool = False,
         http_client: Optional[httpx.AsyncClient] = None,
@@ -59,7 +56,6 @@ def __init__(
         """Initializes AsyncChat."""
         self.completions = AsyncCompletions(
             base_url=base_url,
-            endpoint_id=endpoint_id,
             use_protobuf=use_protobuf,
             use_grpc=use_grpc,
             http_client=http_client,
friendli/sdk/api/chat/completions.py (6 changes: 4 additions & 2 deletions)

@@ -148,6 +148,7 @@ def create(
         request_dict = {
             "messages": messages,
             "stream": stream,
+            "model": model,
             "frequency_penalty": frequency_penalty,
             "presence_penalty": presence_penalty,
             "repetition_penalty": repetition_penalty,
@@ -166,7 +167,7 @@ def create(
             "tool_choice": tool_choice,
             "response_format": response_format,
         }
-        response = self._request(data=request_dict, stream=stream, model=model)
+        response = self._request(data=request_dict, stream=stream)

         if stream:
             return ChatCompletionStream(response=response)
@@ -292,6 +293,7 @@ async def create(
         request_dict = {
             "messages": messages,
             "stream": stream,
+            "model": model,
             "frequency_penalty": frequency_penalty,
             "presence_penalty": presence_penalty,
             "repetition_penalty": repetition_penalty,
@@ -310,7 +312,7 @@ async def create(
             "tool_choice": tool_choice,
             "response_format": response_format,
         }
-        response = await self._request(data=request_dict, stream=stream, model=model)
+        response = await self._request(data=request_dict, stream=stream)

         if stream:
             return AsyncChatCompletionStream(response=response)
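With `model` folded into `request_dict`, the serialized request body is now fully assembled in `create()`, and `_build_content` only encodes it. A sketch of the resulting JSON body for a non-streaming chat call; the field values are placeholders, and None-valued optional parameters are omitted here for brevity:

import json

request_dict = {
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": False,
    "model": "meta-llama-3.1-8b-instruct",  # placeholder model name
}

# _build_content() ends with json.dumps(data).encode() for JSON requests.
body = json.dumps(request_dict).encode()
print(body)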
friendli/sdk/api/completions.py (6 changes: 4 additions & 2 deletions)

@@ -296,6 +296,7 @@ def create(

         request_dict = {
             "stream": stream,
+            "model": model,
             "prompt": prompt,
             "tokens": tokens,
             "timeout_microseconds": timeout_microseconds,
@@ -331,7 +332,7 @@ def create(
             "forced_output_tokens": forced_output_tokens,
             "eos_token": eos_token,
         }
-        response = self._request(data=request_dict, stream=stream, model=model)
+        response = self._request(data=request_dict, stream=stream)

         if stream:
             if self._use_grpc:
@@ -589,6 +590,7 @@ async def main() -> None:

         request_dict = {
             "stream": stream,
+            "model": model,
             "prompt": prompt,
             "tokens": tokens,
             "timeout_microseconds": timeout_microseconds,
@@ -624,7 +626,7 @@ async def main() -> None:
             "forced_output_tokens": forced_output_tokens,
             "eos_token": eos_token,
         }
-        response = await self._request(data=request_dict, stream=stream, model=model)
+        response = await self._request(data=request_dict, stream=stream)

         if stream:
             if self._use_grpc:
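For completeness, the async path after this commit; `AsyncFriendli` is assumed to be the async counterpart exported next to `Friendli` (only the async internals appear in this diff), and the token and model name are placeholders:

import asyncio

# Assumed export: the diff shows AsyncServingAPI/AsyncCompletions internals,
# not the public async client name.
from friendli.sdk.client import AsyncFriendli


async def main() -> None:
    client = AsyncFriendli(token="YOUR_FRIENDLI_TOKEN")  # placeholder token
    completion = await client.completions.create(
        model="meta-llama-3.1-8b-instruct",  # placeholder model name
        prompt="Say this is a test.",
        stream=False,
    )
    print(completion)


asyncio.run(main())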
(Diffs for the remaining 5 of the 11 changed files did not load and are not shown.)