Skip to content

Commit

Permalink
feat(adapter): add get_model_list method
Browse files Browse the repository at this point in the history
feat(openai): remove MODEL constant

feat(user): refactor get_model method

The changes in this commit include:
- Added `get_model_list` method to the `adapter` module in `llmkira/sdk/adapter.py` file.
- Removed the `MODEL` constant from the `openai` module in `llmkira/sdk/endpoint/openai/__init__.py` file.
- Refactored the `get_model` method in the `user` module in `llmkira/extra/user/__init__.py` file.
  • Loading branch information
sudoskys committed Nov 12, 2023
1 parent 020a6ca commit d785885
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 21 deletions.
9 changes: 5 additions & 4 deletions llmkira/extra/user/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from typing import List, Union, Optional
from urllib.parse import urlparse

from llmkira.sdk.endpoint.openai import MODEL
from llmkira.sdk.adapter import SCHEMA_GROUP
from llmkira.sdk.func_calling import ToolRegister

from .client import UserCostClient, UserConfigClient, UserCost, UserConfig
from .schema import UserDriverMode
from ...sdk.endpoint import Driver
Expand Down Expand Up @@ -40,7 +41,7 @@ async def get_cost_by_uid(uid: str) -> List[UserCost]:
class UserControl(object):
@staticmethod
def get_model():
return MODEL.__args__
return SCHEMA_GROUP.get_model_list()

@staticmethod
async def get_driver_config(
Expand Down Expand Up @@ -105,8 +106,8 @@ async def set_endpoint(
:return: new_driver
"""
# assert model in MODEL.__args__, f"openai model is not valid,must be one of {MODEL.__args__}"
if model not in MODEL.__args__:
model = MODEL.__args__[0]
if model not in UserControl.get_model():
model = UserControl.get_model()[0]
_user_data = await UserConfigClient().read_by_uid(uid=uid)
_user_data = _user_data or UserConfig(uid=uid)
new_driver = Driver(endpoint=endpoint, api_key=api_key, model=model, org_id=org_id)
Expand Down
3 changes: 3 additions & 0 deletions llmkira/sdk/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ def get_by_model_name(self,
f"please check your model name"
)

def get_model_list(self):
return [model.model_name for model in self.model_list]

def get_token_limit(self,
*,
model_name: str
Expand Down
17 changes: 0 additions & 17 deletions llmkira/sdk/endpoint/openai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,6 @@

load_dotenv()

MODEL = Literal[
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-4-0613",
"gpt-4",
"gpt-4-32k",
"gpt-4-32k-0613",
# "gpt-3.5-turbo-instruct",
# "gpt-4-0314",
# "gpt-3.5-turbo-0301",
# "gpt-4-32k-0314"
# Do not use 0314. See:
# https://platform.openai.com/docs/guides/gpt/function-calling
]


class OpenaiResult(LlmResult):
class Usage(BaseModel):
Expand Down

0 comments on commit d785885

Please sign in to comment.