diff --git a/llmkira/extra/user/__init__.py b/llmkira/extra/user/__init__.py
index da14b8198..516f3494c 100644
--- a/llmkira/extra/user/__init__.py
+++ b/llmkira/extra/user/__init__.py
@@ -6,8 +6,9 @@
 from typing import List, Union, Optional
 from urllib.parse import urlparse
 
-from llmkira.sdk.endpoint.openai import MODEL
+from llmkira.sdk.adapter import SCHEMA_GROUP
 from llmkira.sdk.func_calling import ToolRegister
+
 from .client import UserCostClient, UserConfigClient, UserCost, UserConfig
 from .schema import UserDriverMode
 from ...sdk.endpoint import Driver
@@ -40,7 +41,7 @@ async def get_cost_by_uid(uid: str) -> List[UserCost]:
 class UserControl(object):
     @staticmethod
     def get_model():
-        return MODEL.__args__
+        return SCHEMA_GROUP.get_model_list()
 
     @staticmethod
     async def get_driver_config(
@@ -105,8 +106,8 @@ async def set_endpoint(
         :return: new_driver
         """
         # assert model in MODEL.__args__, f"openai model is not valid,must be one of {MODEL.__args__}"
-        if model not in MODEL.__args__:
-            model = MODEL.__args__[0]
+        if model not in UserControl.get_model():
+            model = UserControl.get_model()[0]
         _user_data = await UserConfigClient().read_by_uid(uid=uid)
         _user_data = _user_data or UserConfig(uid=uid)
         new_driver = Driver(endpoint=endpoint, api_key=api_key, model=model, org_id=org_id)
diff --git a/llmkira/sdk/adapter.py b/llmkira/sdk/adapter.py
index 76ff0845e..05455c646 100644
--- a/llmkira/sdk/adapter.py
+++ b/llmkira/sdk/adapter.py
@@ -53,6 +53,9 @@ def get_by_model_name(self,
             f"please check your model name"
         )
 
+    def get_model_list(self):
+        return [model.model_name for model in self.model_list]
+
     def get_token_limit(self, *,
                         model_name: str
diff --git a/llmkira/sdk/endpoint/openai/__init__.py b/llmkira/sdk/endpoint/openai/__init__.py
index 7f7741cbd..ae927cdba 100644
--- a/llmkira/sdk/endpoint/openai/__init__.py
+++ b/llmkira/sdk/endpoint/openai/__init__.py
@@ -23,23 +23,6 @@
 
 load_dotenv()
 
-MODEL = Literal[
-    "gpt-3.5-turbo",
-    "gpt-3.5-turbo-16k",
-    "gpt-3.5-turbo-0613",
-    "gpt-3.5-turbo-16k-0613",
-    "gpt-4-0613",
-    "gpt-4",
-    "gpt-4-32k",
-    "gpt-4-32k-0613",
-    # "gpt-3.5-turbo-instruct",
-    # "gpt-4-0314",
-    # "gpt-3.5-turbo-0301",
-    # "gpt-4-32k-0314"
-    # Do not use 0314. See:
-    # https://platform.openai.com/docs/guides/gpt/function-calling
-]
-
 
 class OpenaiResult(LlmResult):
     class Usage(BaseModel):
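
The patch removes the hardcoded `MODEL` Literal from the OpenAI endpoint and routes model discovery through the adapter layer: `UserControl.get_model()` now calls the new `SCHEMA_GROUP.get_model_list()`, and `set_endpoint` falls back to the first registered model when the requested name is unknown. Below is a minimal, self-contained sketch of that registry pattern and of the fallback check; the `SchemaGroup`/`ModelSchema` classes, their fields, and the sample model entries are simplified stand-ins for illustration only, not the actual definitions in `llmkira/sdk/adapter.py`.

```python
# Sketch only: approximates the registry pattern this diff relies on.
# ModelSchema and the example entries are assumptions, not llmkira code.
from dataclasses import dataclass
from typing import List


@dataclass
class ModelSchema:
    model_name: str
    token_limit: int


class SchemaGroup:
    def __init__(self, model_list: List[ModelSchema]):
        self.model_list = model_list

    def get_model_list(self) -> List[str]:
        # Mirrors the method added in llmkira/sdk/adapter.py
        return [model.model_name for model in self.model_list]


SCHEMA_GROUP = SchemaGroup(
    model_list=[
        ModelSchema(model_name="gpt-3.5-turbo", token_limit=4096),
        ModelSchema(model_name="gpt-4", token_limit=8192),
    ]
)


def pick_model(model: str) -> str:
    """Fallback logic equivalent to the check in UserControl.set_endpoint."""
    models = SCHEMA_GROUP.get_model_list()
    return model if model in models else models[0]


if __name__ == "__main__":
    print(pick_model("gpt-4"))        # known model is kept -> "gpt-4"
    print(pick_model("not-a-model"))  # unknown model falls back -> "gpt-3.5-turbo"
```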