Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: compatible with pydantic v1 #180

Merged
merged 12 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
[![Release Notes](https://img.shields.io/github/release/baidubce/bce-qianfan-sdk)](https://github.com/baidubce/bce-qianfan-sdk/releases)
[![PyPI version](https://badge.fury.io/py/qianfan.svg)](https://pypi.org/project/qianfan/)
[![Documentation Status](https://readthedocs.org/projects/qianfan/badge/?version=stable)](https://qianfan.readthedocs.io/en/stable/README.html)
[![Feedback Issue](https://img.shields.io/badge/%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC-GitHub_Issue-brightgreen)](https://github.com/baidubce/bce-qianfan-sdk/issues)
[![Feedback Ticket](https://img.shields.io/badge/%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC-%E7%99%BE%E5%BA%A6%E6%99%BA%E8%83%BD%E4%BA%91%E5%B7%A5%E5%8D%95-brightgreen)](https://console.bce.baidu.com/ticket/#/ticket/create?productId=279)

## 简介

Expand Down Expand Up @@ -171,7 +173,7 @@ trainer.run()
> Check [**API References**](https://qianfan.readthedocs.io/en/stable/qianfan.html) for more details.


### 联系我们
## 联系我们
如使用过程中遇到什么问题,或对SDK功能有建议,可通过如下方式联系我们
- [GitHub issues](https://github.com/baidubce/bce-qianfan-sdk/issues)
- [百度智能云工单](https://console.bce.baidu.com/ticket/#/ticket/create?productId=279) (百度专家即时服务)
Expand Down
11 changes: 9 additions & 2 deletions README.pypi.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
[![Release Notes](https://img.shields.io/github/release/baidubce/bce-qianfan-sdk)](https://github.com/baidubce/bce-qianfan-sdk/releases)
[![PyPI version](https://badge.fury.io/py/qianfan.svg)](https://pypi.org/project/qianfan/)
[![Documentation Status](https://readthedocs.org/projects/qianfan/badge/?version=stable)](https://qianfan.readthedocs.io/en/stable/README.html)
[![Feedback Issue](https://img.shields.io/badge/%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC-GitHub_Issue-brightgreen)](https://github.com/baidubce/bce-qianfan-sdk/issues)
[![Feedback Ticket](https://img.shields.io/badge/%E8%81%94%E7%B3%BB%E6%88%91%E4%BB%AC-%E7%99%BE%E5%BA%A6%E6%99%BA%E8%83%BD%E4%BA%91%E5%B7%A5%E5%8D%95-brightgreen)](https://console.bce.baidu.com/ticket/#/ticket/create?productId=279)

[Documentation](https://qianfan.readthedocs.io/en/stable/README.html) | [GitHub](https://github.com/baidubce/bce-qianfan-sdk) | [Cookbook](https://github.com/baidubce/bce-qianfan-sdk/tree/main/cookbook)

Expand Down Expand Up @@ -108,10 +110,15 @@ print(resp["result"])
- Tokenizer [[Doc](https://qianfan.readthedocs.io/en/stable/docs/utils.html)][[GitHub](https://github.com/baidubce/bce-qianfan-sdk/blob/main/docs/utils.md)]
- 接口流控 [[Doc](https://qianfan.readthedocs.io/en/stable/docs/configurable.html)][[GitHub](https://github.com/baidubce/bce-qianfan-sdk/blob/main/docs/configurable.md)]

详细信息请参考相应的文档。如果有任何问题,欢迎前往 [GitHub](https://github.com/baidubce/bce-qianfan-sdk) 提交 issue。

> 还可以通过 [**API References**](https://qianfan.readthedocs.io/en/stable/qianfan.html) 查看每个接口的详细说明。

## 联系我们

如使用过程中遇到什么问题,或对SDK功能有建议,可通过如下方式联系我们

- [GitHub issues](https://github.com/baidubce/bce-qianfan-sdk/issues)
- [百度智能云工单](https://console.bce.baidu.com/ticket/#/ticket/create?productId=279) (百度专家即时服务)

## License

Apache-2.0
7 changes: 4 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ aiolimiter = ">=1.1.0"
importlib-metadata = { version = ">=1.4.0", python = "<=3.7" }
bce-python-sdk = ">=0.8.79"
typing-extensions = { version = ">=4.0.0", python = "<=3.10" }
pydantic = ">=2"
pydantic-settings = ">=2.0.3"
pydantic = "*"
python-dotenv = "<=0.21.1"
langchain = { version = ">=0.0.321", python = ">=3.8.1", optional = true }
numpy = [
{ version = "<1.22.0", python = ">=3.7 <3.8" },
Expand Down Expand Up @@ -53,6 +53,7 @@ sphinx-rtd-theme = ">=1.2.0"
mypy = ">=1.4.0"
myst-parser = ">=0.19.2"
pytest-mock = "3.11.1"
types-protobuf = "4.24.0.4"

[tool.poetry.extras]
langchain = ["langchain"]
Expand All @@ -72,7 +73,7 @@ preview = true
[tool.mypy]
ignore_missing_imports = "True"
disallow_untyped_defs = "True"
exclude = ["qianfan/tests"]
exclude = ["qianfan/tests", "qianfan/pydantic"]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是避免什么问题



[build-system]
Expand Down
9 changes: 4 additions & 5 deletions src/qianfan/common/tool/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,19 @@
import string
from typing import Any, Dict, List, Optional, Type

from pydantic import BaseModel
from pydantic.v1 import (
from qianfan.pydantic import (
BaseModel as PydanticV1BaseModel,
)
from pydantic.v1 import (
from qianfan.pydantic import (
Field as PydanticV1Field,
)
from pydantic.v1 import (
from qianfan.pydantic import (
create_model as create_pydantic_v1_model,
)
from qianfan.utils.utils import assert_package_installed


class ToolParameter(BaseModel):
class ToolParameter(PydanticV1BaseModel):
"""
Tool parameters, used to define the inputs when calling a tool and
to describe the parameters needed when calling the tool to the model.
Expand Down
9 changes: 5 additions & 4 deletions src/qianfan/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@
import os
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import deprecated

from qianfan.consts import DefaultValue, Env
from qianfan.pydantic import BaseSettings, Field


class GlobalConfig(BaseSettings):
"""
The global config of whole qianfan sdk
"""

model_config = SettingsConfigDict(env_prefix="QIANFAN_", case_sensitive=True)
class Config:
env_file_encoding = "utf-8"
env_prefix = "QIANFAN_"
case_sensitive = True
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

BaseSettings内部类Config用于设置BaseSetting的一些元信息和行为选项


AK: Optional[str] = Field(default=None)
SK: Optional[str] = Field(default=None)
Expand All @@ -47,7 +49,6 @@ class GlobalConfig(BaseSettings):

# for private
ENABLE_PRIVATE: bool = Field(default=DefaultValue.EnablePrivate)
# todo 补充 ENABLE_AUTH 的默认值和使用方法
ENABLE_AUTH: Optional[bool] = Field(default=None)
ACCESS_CODE: Optional[str] = Field(default=None)
IMPORT_STATUS_POLLING_INTERVAL: float = Field(
Expand Down
2 changes: 1 addition & 1 deletion src/qianfan/dataset/data_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""


from pydantic import BaseModel, Field
from qianfan.pydantic import BaseModel, Field


class QianfanOperator(BaseModel):
Expand Down
26 changes: 14 additions & 12 deletions src/qianfan/dataset/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@

import dateutil.parser
import requests
from pydantic import BaseModel, Field, model_validator

from qianfan.config import get_config
from qianfan.dataset.consts import QianfanDatasetLocalCacheDir
from qianfan.errors import FileSizeOverflow, QianfanRequestError
from qianfan.pydantic import BaseModel, Field, root_validator
from qianfan.resources.console.consts import (
DataExportDestinationType,
DataExportStatus,
Expand Down Expand Up @@ -261,24 +261,26 @@ def set_format_type(self, format_type: FormatType) -> None:
"""
self.file_format = format_type

@model_validator(mode="after")
def _format_check(self) -> "FileDataSource":
if self.file_format:
return self
@root_validator
def _format_check(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values["file_format"]:
return values

path = values["path"]

try:
index = self.path.rfind(".")
index = path.rfind(".")
# 读文件夹或查询不到的情况下默认使用纯文本格式
if os.path.isdir(self.path) or index == -1:
if os.path.isdir(path) or index == -1:
log_warn(f"use default format type {FormatType.Text}")
self.file_format = FormatType.Text
return self
suffix = self.path[index + 1 :]
values["file_format"] = FormatType.Text
return values
suffix = path[index + 1 :]
for t in FormatType:
if t.value == suffix:
self.file_format = t
values["file_format"] = t
log_info(f"use format type {t}")
return self
return values
raise ValueError(f"cannot match proper format type for {suffix}")
except Exception as e:
log_error(str(e))
Expand Down
2 changes: 1 addition & 1 deletion src/qianfan/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def start_online_data_process_task(self, operators: List[QianfanOperator]) -> in
"desensitization": [],
}
for operator in operators:
attr_dict = operator.model_dump()
attr_dict = operator.dict()
attr_dict.pop("operator_name")
attr_dict.pop("operator_type")

Expand Down
2 changes: 1 addition & 1 deletion src/qianfan/dataset/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pyarrow
import pyarrow.compute as pc
from pyarrow import Table as PyarrowTable
from pydantic import BaseModel
from typing_extensions import Self

from qianfan.dataset.consts import (
Expand All @@ -32,6 +31,7 @@
Processable,
)
from qianfan.dataset.table_utils import _construct_table_from_nest_sequence
from qianfan.pydantic import BaseModel
from qianfan.utils import log_debug, log_error, log_info, log_warn


Expand Down
5 changes: 2 additions & 3 deletions src/qianfan/evaluation/evaluation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from concurrent.futures import ALL_COMPLETED, Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel, Field, model_validator

from qianfan import get_config
from qianfan.dataset import Dataset, QianfanDataSource
from qianfan.dataset.consts import (
Expand All @@ -43,6 +41,7 @@
QianfanRuleEvaluator,
)
from qianfan.model import Model, Service
from qianfan.pydantic import BaseModel, Field, root_validator
from qianfan.resources import Model as ResourceModel
from qianfan.resources.console.consts import EvaluationTaskStatus
from qianfan.utils import log_debug, log_error, log_info, log_warn
Expand All @@ -55,7 +54,7 @@ class EvaluationManager(BaseModel):
local_evaluators: Optional[List[LocalEvaluator]] = Field(default=None)
qianfan_evaluators: Optional[List[QianfanEvaluator]] = Field(default=None)

@model_validator(mode="before")
@root_validator
@classmethod
def _check_evaluators(cls, input_dict: Any) -> Any:
"""校验传入的参数"""
Expand Down
18 changes: 10 additions & 8 deletions src/qianfan/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, model_validator

from qianfan.evaluation.consts import (
QianfanRefereeEvaluatorDefaultMaxScore,
QianfanRefereeEvaluatorDefaultMetrics,
QianfanRefereeEvaluatorDefaultSteps,
)
from qianfan.pydantic import BaseModel, Field, root_validator
from qianfan.utils import log_error, log_warn


Expand Down Expand Up @@ -85,7 +84,7 @@ class QianfanManualEvaluator(QianfanEvaluator):
default=[ManualEvaluatorDimension(dimension="满意度")]
)

@model_validator(mode="before")
@root_validator
@classmethod
def dimension_validation(cls, input_dict: Any) -> Any:
assert isinstance(input_dict, dict)
Expand Down Expand Up @@ -120,16 +119,19 @@ class Config:

open_compass_evaluator: BaseEvaluator

@model_validator(mode="after")
def _check_open_compass_evaluator(self) -> "OpenCompassLocalEvaluator":
signature = inspect.signature(self.open_compass_evaluator.score)
@root_validator
def _check_open_compass_evaluator(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
open_compass_evaluator = values["open_compass_evaluator"]
signature = inspect.signature(open_compass_evaluator.score)
params = list(signature.parameters.keys())
params.sort()
if params != ["predictions", "references"]:
raise ValueError(
f"unsupported opencompass evaluator {self.open_compass_evaluator}"
f"unsupported opencompass evaluator {type(open_compass_evaluator)}"
)
return self
return values

def evaluate(
self, input: Union[str, List[Dict[str, Any]]], reference: str, output: str
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def _agent_input_keys() -> List[str]:

def _agent_validate_logical_core(values: dict) -> dict:
"""check if llm is valid"""
if not isinstance(values["llm"], QianfanChatEndpoint):
if not isinstance(values["llm"], QianfanChatEndpoint): # type: ignore
raise ValueError("Only supported with QianfanChatEndpoint models.")
if not (values["llm"].model == "ERNIE-Bot" or values["llm"].model == "ERNIE-Bot-4"):
raise ValueError(
Expand Down
3 changes: 1 addition & 2 deletions src/qianfan/model/configs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any

from pydantic import BaseModel

from qianfan.model.consts import ServiceType
from qianfan.pydantic import BaseModel
from qianfan.resources.console import consts as console_consts


Expand Down
22 changes: 22 additions & 0 deletions src/qianfan/pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Pydantic compatible layer
"""

try:
from pydantic.v1 import * # noqa
Copy link
Collaborator

@danielhjz danielhjz Jan 10, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

使用v1 api

except ImportError:
from pydantic import * # noqa
11 changes: 11 additions & 0 deletions src/qianfan/tests/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""
Unit test for config
"""
import os

from qianfan import get_config
from qianfan.consts import DefaultValue

Expand All @@ -24,3 +26,12 @@ def test_rewrite_config_through_code():
assert get_config().AUTH_TIMEOUT == DefaultValue.AuthTimeout
config_center.AUTH_TIMEOUT = 114514
assert get_config().AUTH_TIMEOUT == 114514


def test_read_from_dot_env():
try:
with open(".env", "w") as f:
f.write('QIANFAN_ACCESS_TOKEN="test_token"')
assert get_config().ACCESS_TOKEN == "test_token"
finally:
os.remove(".env")
2 changes: 1 addition & 1 deletion src/qianfan/tests/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from unittest.mock import patch

import pytest
from pydantic import BaseModel

from qianfan.dataset.consts import QianfanDataGroupColumnName
from qianfan.dataset.data_operator import FilterCheckNumberWords
Expand All @@ -30,6 +29,7 @@
QianfanNonSortedConversation,
QianfanSortedConversation,
)
from qianfan.pydantic import BaseModel
from qianfan.resources.console.consts import DataTemplateType


Expand Down
6 changes: 4 additions & 2 deletions src/qianfan/tests/tool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def test_tool_from_langchain_tool():
return

from langchain.tools.base import BaseTool as LangchainBaseTool
from pydantic.v1 import BaseModel, Field

from qianfan.pydantic import BaseModel, Field

class CalculatorToolSchema(BaseModel):
a: int = Field(description="a description")
Expand Down Expand Up @@ -145,7 +146,8 @@ def test_tool_from_langchain_func_tool():
return

from langchain.tools.base import Tool as LangchainTool
from pydantic.v1 import BaseModel, Field

from qianfan.pydantic import BaseModel, Field

def hello(a: str, b: str) -> str:
return f"hello {a} {b}"
Expand Down
Loading
Loading