Skip to content

Commit

Permalink
feat[resource]: supported dynamic infer api model (#361)
Browse files Browse the repository at this point in the history
* feat[resource]: supported dynamic infer api model

* feat[trainer]: support txt2img training

* feat[resource]: fix lint

* fix[resources]: lint

* fix[resources]: fix configs and client show

* fix[resources]: fix ut
  • Loading branch information
danielhjz authored Mar 21, 2024
1 parent e107f9d commit e576b18
Show file tree
Hide file tree
Showing 20 changed files with 274 additions and 79 deletions.
2 changes: 1 addition & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "qianfan"
version = "0.3.5"
version = "0.3.6"
description = "文心千帆大模型平台 Python SDK"
authors = []
license = "Apache-2.0"
Expand Down
6 changes: 3 additions & 3 deletions python/qianfan/common/client/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,12 +155,12 @@ def list_model_callback(
models = t.models()
for m in sorted(models):
info = t.get_model_info(m)
if not info.depracated:
if not info.deprecated:
console.print(m, highlight=False)
for m in sorted(models):
info = t.get_model_info(m)
if info.depracated:
console.print(f"[s]{m} [dim](depracated)[/]", highlight=False)
if info.deprecated:
console.print(f"[s]{m} [dim](deprecated)[/]", highlight=False)
raise typer.Exit()


Expand Down
3 changes: 3 additions & 0 deletions python/qianfan/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ class Config:
ACCESS_TOKEN_REFRESH_MIN_INTERVAL: float = Field(
default=DefaultValue.AccessTokenRefreshMinInterval
)
INFER_RESOURCE_REFRESH_INTERVAL: float = Field(
default=DefaultValue.InferResourceRefreshMinInterval
)
QPS_LIMIT: float = Field(default=DefaultValue.QpsLimit)
RPM_LIMIT: float = Field(default=DefaultValue.RpmLimit)
TPM_LIMIT: int = Field(default=DefaultValue.TpmLimit)
Expand Down
2 changes: 2 additions & 0 deletions python/qianfan/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ class Env:
IAMSignExpirationSeconds: str = "QIANFAN_IAM_SIGN_EXPIRATION_SEC"
ConsoleAPIBaseURL: str = "QIANFAN_CONSOLE_API_BASE_URL"
AccessTokenRefreshMinInterval: str = "QIANFAN_ACCESS_TOKEN_REFRESH_MIN_INTERVAL"
InferResourceRefreshMinInterval: str = "QIANFAN_INFER_RESOURCE_REFRESH_MIN_INTERVAL"
EnablePrivate: str = "QIANFAN_ENABLE_PRIVATE"
AccessCode: str = "QIANFAN_PRIVATE_ACCESS_CODE"
QpsLimit: str = "QIANFAN_QPS_LIMIT"
Expand Down Expand Up @@ -115,6 +116,7 @@ class DefaultValue:
IAMSignExpirationSeconds: int = 300
ConsoleAPIBaseURL: str = "https://qianfan.baidubce.com"
AccessTokenRefreshMinInterval: float = 3600
InferResourceRefreshMinInterval: float = 600
RetryCount: int = 3
RetryTimeout: float = 60
RetryBackoffFactor: float = 1
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

class QianfanRequestSettings(AIRequestSettings):
ai_model_id: Optional[str] = Field(None, serialization_alias="model")
temperature: float = Field(0.95, g=0.0, le=1.0)
temperature: float = Field(0.95, gt=0.0, le=1.0)
top_p: Optional[float] = None
top_k: Optional[int] = None
stream: Optional[bool] = None
Expand Down
54 changes: 36 additions & 18 deletions python/qianfan/resources/images/image2text.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
a dict which key is preset model and value is the endpoint
"""
return {
info_list = {
UNSPECIFIED_MODEL: QfLLMInfo(
endpoint="",
# the key of api is "query", which is conflict with query in params
Expand All @@ -55,6 +55,20 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
},
),
}
# 获取最新的模型列表
latest_models_list = super()._supported_models()
for m in latest_models_list:
if m not in info_list:
info_list[m] = latest_models_list[m]
else:
# 更新endpoint
info_list[m].endpoint = latest_models_list[m].endpoint

return info_list

@classmethod
def api_type(cls) -> str:
return "image2text"

@classmethod
def _default_model(self) -> str:
Expand Down Expand Up @@ -87,30 +101,34 @@ def _generate_body(
for key in IGNORED_KEYS:
if key in kwargs:
del kwargs[key]
if model is not None and model in self._supported_models():
model_info = self._supported_models()[model]
if model is not None and self.get_model_info(model):
model_info = self.get_model_info(model)
# warn if user provide unexpected arguments
for key in kwargs:
if (
key not in model_info.required_keys
and key not in model_info.optional_keys
):
log_warn(
f"This key `{key}` does not seem to be a parameter that the"
f" model `{model}` will accept"
)
if model_info.deprecated:
# 动态获取的模型暂时不做字段校验:
for key in kwargs:
if (
key not in model_info.required_keys
and key not in model_info.optional_keys
):
log_warn(
f"This key `{key}` does not seem to be a parameter that the"
f" model `{model}` will accept"
)
else:
default_model_info = self._supported_models()[self._default_model()]
default_model_info = self.get_model_info(self._default_model())
if endpoint == default_model_info.endpoint:
model_info = default_model_info
else:
model_info = self._supported_models()[UNSPECIFIED_MODEL]

for key in model_info.required_keys:
if key not in kwargs:
raise errors.ArgumentNotFoundError(
f"The required key `{key}` is not provided."
)
if model_info.deprecated:
# 动态获取的模型暂时不做字段校验:
for key in model_info.required_keys:
if key not in kwargs:
raise errors.ArgumentNotFoundError(
f"The required key `{key}` is not provided."
)
if stream is True:
kwargs["stream"] = True
return kwargs
Expand Down
16 changes: 15 additions & 1 deletion python/qianfan/resources/images/text2image.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
a dict which key is preset model and value is the endpoint
"""
return {
info_list = {
"Stable-Diffusion-XL": QfLLMInfo(
endpoint="/text2image/sd_xl",
required_keys={"prompt"},
Expand All @@ -72,6 +72,20 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
},
),
}
# 获取最新的模型列表
latest_models_list = super()._supported_models()
for m in latest_models_list:
if m not in info_list:
info_list[m] = latest_models_list[m]
else:
# 更新endpoint
info_list[m].endpoint = latest_models_list[m].endpoint

return info_list

@classmethod
def api_type(cls) -> str:
return "text2image"

@classmethod
def _default_model(self) -> str:
Expand Down
147 changes: 118 additions & 29 deletions python/qianfan/resources/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import copy
import threading
from concurrent.futures import Future, ThreadPoolExecutor
from datetime import MINYEAR, datetime
from typing import (
Any,
AsyncIterator,
Expand All @@ -37,6 +38,7 @@
import qianfan.errors as errors
from qianfan import get_config
from qianfan.consts import Consts, DefaultValue
from qianfan.resources.console.service import Service
from qianfan.resources.requestor.openapi_requestor import create_api_requestor
from qianfan.resources.typing import JsonBody, QfLLMInfo, QfResponse, RetryConfig
from qianfan.utils import log_info, log_warn, utils
Expand Down Expand Up @@ -173,7 +175,7 @@ def _update_model_and_endpoint(
endpoint = self._endpoint
if endpoint is None:
model_name = self._default_model() if model is None else model
model_info = self._supported_models().get(model_name, None)
model_info = self.get_model_info(model_name)
if model_info is None:
raise errors.InvalidArgumentError(
f"The provided model `{model}` is not in the list of supported"
Expand Down Expand Up @@ -371,6 +373,10 @@ def _check_params(
if stream is True and retry_count != 1:
log_warn("retry is not available when stream is enabled")

@classmethod
def api_type(cls) -> str:
return ""

@classmethod
def _supported_models(cls) -> Dict[str, QfLLMInfo]:
"""
Expand All @@ -383,7 +389,11 @@ def _supported_models(cls) -> Dict[str, QfLLMInfo]:
a dict which key is preset model and value is the endpoint
"""
raise NotImplementedError
# raise NotImplementedError
if cls.api_type == "":
return {}
else:
return get_latest_supported_models().get(cls.api_type(), {})

@classmethod
def _default_model(cls) -> str:
Expand All @@ -409,8 +419,10 @@ def get_model_info(cls, model: str) -> QfLLMInfo:
Return:
Information of the model
"""
model_info = cls._supported_models().get(model)
model_info_list = {k.lower(): v for k, v in cls._supported_models().items()}
model_info = model_info_list.get(model.lower())
if model_info is None:
# 拿不到的话
raise errors.InvalidArgumentError(
f"The provided model `{model}` is not in the list of supported models."
" If this is a recently added model, try using the `endpoint`"
Expand All @@ -433,16 +445,14 @@ def _get_endpoint(self, model: str) -> QfLLMInfo:
Raises:
QianfanError: if the input is not in self._supported_models()
"""
if model not in self._supported_models():
try:
model_info = self.get_model_info(model)
except errors.InvalidArgumentError:
if self._endpoint is not None:
return QfLLMInfo(endpoint=self._endpoint)
raise errors.InvalidArgumentError(
f"The provided model `{model}` is not in the list of supported models."
" If this is a recently added model, try using the `endpoint`"
" arguments and create an issue to tell us. Supported models:"
f" {self.models()}"
)
return self._supported_models()[model]
else:
raise
return model_info

def _get_endpoint_from_dict(
self, model: Optional[str], endpoint: Optional[str], stream: bool, **kwargs: Any
Expand Down Expand Up @@ -503,25 +513,35 @@ def _generate_body(
for key in IGNORED_KEYS:
if key in kwargs:
del kwargs[key]
if model is not None and model in self._supported_models():
model_info = self._supported_models()[model]
# warn if user provide unexpected arguments
for key in kwargs:
if (
key not in model_info.required_keys
and key not in model_info.optional_keys
):
log_warn(
f"This key `{key}` does not seem to be a parameter that the"
f" model `{model}` will accept"
)
else:
default_model_info = self._supported_models()[self._default_model()]
if endpoint == default_model_info.endpoint:
model_info = default_model_info
else:
model_info = self._supported_models()[UNSPECIFIED_MODEL]
model_info: Optional[QfLLMInfo] = None
if model is not None:
try:
model_info = self.get_model_info(model)
# warn if user provide unexpected arguments
for key in kwargs:
if (
key not in model_info.required_keys
and key not in model_info.optional_keys
):
log_warn(
f"This key `{key}` does not seem to be a parameter that the"
f" model `{model}` will accept"
)
except errors.InvalidArgumentError:
...

if model_info is None:
# 使用默认模型
try:
default_model_info = self.get_model_info(self._default_model())
if default_model_info.endpoint == endpoint:
model_info = default_model_info
except errors.InvalidArgumentError:
...

# 非默认模型
if model_info is None:
model_info = self._supported_models()[UNSPECIFIED_MODEL]
for key in model_info.required_keys:
if key not in kwargs:
raise errors.ArgumentNotFoundError(
Expand Down Expand Up @@ -602,3 +622,72 @@ async def _with_concurrency_limit(
*[asyncio.ensure_future(_with_concurrency_limit(task)) for task in tasks],
return_exceptions=True,
)


# Module-level cache of dynamically discovered inference services,
# shaped as {api_type: {model_name: QfLLMInfo}}.
_runtime_models_info: Dict[str, Dict[str, QfLLMInfo]] = {}
# Timestamp of the last refresh attempt; initialized to MINYEAR so the
# very first call to get_latest_supported_models() always fetches.
_last_update_time: datetime = datetime(MINYEAR, 1, 1)
# Guards concurrent refresh of the two globals above.
_model_infos_access_lock: threading.Lock = threading.Lock()


def trim_prefix(s: str, prefix: str) -> str:
    """Return `s` with a leading `prefix` removed, or `s` unchanged if absent."""
    return s[len(prefix):] if s.startswith(prefix) else s


def get_latest_supported_models() -> Dict[str, Dict[str, QfLLMInfo]]:
    """
    Fetch the latest preset inference services from the console API and
    cache them in the module-level `_runtime_models_info` registry.

    The registry maps api_type -> {model_name: QfLLMInfo}. The remote list
    is refreshed at most once per `INFER_RESOURCE_REFRESH_INTERVAL` seconds;
    between refreshes (and on fetch failure) the cached registry is returned
    unchanged.

    Returns:
        The cached registry, or an empty dict when console credentials
        (ACCESS_KEY / SECRET_KEY) are missing or private deployment is
        enabled (no console API available).
    """
    config = get_config()
    if config.ACCESS_KEY is None or config.SECRET_KEY is None:
        return {}

    if config.ENABLE_PRIVATE:
        # Private deployments have no console API; skip dynamic discovery.
        return {}

    global _last_update_time
    # BUGFIX: throttle on the interval introduced for inference-resource
    # refresh, not on the access-token refresh interval.
    refresh_interval = config.INFER_RESOURCE_REFRESH_INTERVAL
    if (datetime.now() - _last_update_time).total_seconds() <= refresh_interval:
        return _runtime_models_info

    # `with` guarantees the lock is released on every exit path, including
    # unexpected exceptions (the manual acquire/release could leak it).
    with _model_infos_access_lock:
        # Double-checked: another thread may have refreshed while we waited.
        if (datetime.now() - _last_update_time).total_seconds() <= refresh_interval:
            return _runtime_models_info
        try:
            svc_list = Service.list()["result"]["common"]
        except Exception as e:
            log_warn(f"fetch_supported_models failed: {e}")
            # Record the attempt time so a failing console API is not
            # hammered on every subsequent call.
            _last_update_time = datetime.now()
            return _runtime_models_info

        # Each preset service url looks like
        # "{BaseURL}{ModelAPIPrefix}/{api_type}/{model_endpoint}".
        url_prefix = "{}{}/".format(DefaultValue.BaseURL, Consts.ModelAPIPrefix)
        for s in svc_list:
            parts = trim_prefix(s["url"], url_prefix).split("/")
            if len(parts) != 2:
                # Unexpected url shape; skip rather than crash the refresh.
                log_warn(f"skip unrecognized service url: {s['url']}")
                continue
            api_type, model_endpoint = parts
            model_info = _runtime_models_info.setdefault(api_type, {})
            model_info[s["name"]] = QfLLMInfo(
                endpoint="/{}/{}".format(api_type, model_endpoint),
                api_type=api_type,
            )
        _last_update_time = datetime.now()
    return _runtime_models_info


get_latest_supported_models()
Loading

0 comments on commit e576b18

Please sign in to comment.