Skip to content

Commit

Permalink
feat: compatible with pydantic v1 (#180)
Browse files Browse the repository at this point in the history
* readme: add feedback badge

* update badge

* update badge

* update pypi readme

* switch pydantic-settings to dynaconf

* fix ci

* pydantic compatibility

* format

* rollback

* format

* fix lint

---------

Co-authored-by: ZingLix <[email protected]>
  • Loading branch information
Dobiichi-Origami and ZingLix authored Jan 10, 2024
1 parent 9eeb914 commit 32a89b3
Show file tree
Hide file tree
Showing 15 changed files with 81 additions and 45 deletions.
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ aiolimiter = ">=1.1.0"
importlib-metadata = { version = ">=1.4.0", python = "<=3.7" }
bce-python-sdk = ">=0.8.79"
typing-extensions = { version = ">=4.0.0", python = "<=3.10" }
pydantic = ">=2"
pydantic-settings = ">=2.0.3"
pydantic = "*"
python-dotenv = "<=0.21.1"
langchain = { version = ">=0.0.321", python = ">=3.8.1", optional = true }
numpy = [
{ version = "<1.22.0", python = ">=3.7 <3.8" },
Expand Down Expand Up @@ -73,7 +73,7 @@ preview = true
[tool.mypy]
ignore_missing_imports = "True"
disallow_untyped_defs = "True"
exclude = ["qianfan/tests"]
exclude = ["qianfan/tests", "qianfan/pydantic"]


[build-system]
Expand Down
9 changes: 4 additions & 5 deletions src/qianfan/common/tool/base_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,19 @@
import string
from typing import Any, Dict, List, Optional, Type

from pydantic import BaseModel
from pydantic.v1 import (
from qianfan.pydantic import (
BaseModel as PydanticV1BaseModel,
)
from pydantic.v1 import (
from qianfan.pydantic import (
Field as PydanticV1Field,
)
from pydantic.v1 import (
from qianfan.pydantic import (
create_model as create_pydantic_v1_model,
)
from qianfan.utils.utils import assert_package_installed


class ToolParameter(BaseModel):
class ToolParameter(PydanticV1BaseModel):
"""
Tool parameters, used to define the inputs when calling a tool and
to describe the parameters needed when calling the tool to the model.
Expand Down
9 changes: 5 additions & 4 deletions src/qianfan/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,21 @@
import os
from typing import Optional

from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict
from typing_extensions import deprecated

from qianfan.consts import DefaultValue, Env
from qianfan.pydantic import BaseSettings, Field


class GlobalConfig(BaseSettings):
"""
The global config of whole qianfan sdk
"""

model_config = SettingsConfigDict(env_prefix="QIANFAN_", case_sensitive=True)
class Config:
env_file_encoding = "utf-8"
env_prefix = "QIANFAN_"
case_sensitive = True

AK: Optional[str] = Field(default=None)
SK: Optional[str] = Field(default=None)
Expand All @@ -47,7 +49,6 @@ class GlobalConfig(BaseSettings):

# for private
ENABLE_PRIVATE: bool = Field(default=DefaultValue.EnablePrivate)
# todo 补充 ENABLE_AUTH 的默认值和使用方法
ENABLE_AUTH: Optional[bool] = Field(default=None)
ACCESS_CODE: Optional[str] = Field(default=None)
IMPORT_STATUS_POLLING_INTERVAL: float = Field(
Expand Down
2 changes: 1 addition & 1 deletion src/qianfan/dataset/data_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""


from pydantic import BaseModel, Field
from qianfan.pydantic import BaseModel, Field


class QianfanOperator(BaseModel):
Expand Down
26 changes: 14 additions & 12 deletions src/qianfan/dataset/data_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@

import dateutil.parser
import requests
from pydantic import BaseModel, Field, model_validator

from qianfan.config import get_config
from qianfan.dataset.consts import QianfanDatasetLocalCacheDir
from qianfan.errors import FileSizeOverflow, QianfanRequestError
from qianfan.pydantic import BaseModel, Field, root_validator
from qianfan.resources.console.consts import (
DataExportDestinationType,
DataExportStatus,
Expand Down Expand Up @@ -261,24 +261,26 @@ def set_format_type(self, format_type: FormatType) -> None:
"""
self.file_format = format_type

@model_validator(mode="after")
def _format_check(self) -> "FileDataSource":
if self.file_format:
return self
@root_validator
def _format_check(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if values["file_format"]:
return values

path = values["path"]

try:
index = self.path.rfind(".")
index = path.rfind(".")
# 读文件夹或查询不到的情况下默认使用纯文本格式
if os.path.isdir(self.path) or index == -1:
if os.path.isdir(path) or index == -1:
log_warn(f"use default format type {FormatType.Text}")
self.file_format = FormatType.Text
return self
suffix = self.path[index + 1 :]
values["file_format"] = FormatType.Text
return values
suffix = path[index + 1 :]
for t in FormatType:
if t.value == suffix:
self.file_format = t
values["file_format"] = t
log_info(f"use format type {t}")
return self
return values
raise ValueError(f"cannot match proper format type for {suffix}")
except Exception as e:
log_error(str(e))
Expand Down
2 changes: 1 addition & 1 deletion src/qianfan/dataset/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,7 +690,7 @@ def start_online_data_process_task(self, operators: List[QianfanOperator]) -> in
"desensitization": [],
}
for operator in operators:
attr_dict = operator.model_dump()
attr_dict = operator.dict()
attr_dict.pop("operator_name")
attr_dict.pop("operator_type")

Expand Down
2 changes: 1 addition & 1 deletion src/qianfan/dataset/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import pyarrow
import pyarrow.compute as pc
from pyarrow import Table as PyarrowTable
from pydantic import BaseModel
from typing_extensions import Self

from qianfan.dataset.consts import (
Expand All @@ -32,6 +31,7 @@
Processable,
)
from qianfan.dataset.table_utils import _construct_table_from_nest_sequence
from qianfan.pydantic import BaseModel
from qianfan.utils import log_debug, log_error, log_info, log_warn


Expand Down
5 changes: 2 additions & 3 deletions src/qianfan/evaluation/evaluation_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
from concurrent.futures import ALL_COMPLETED, Future, ThreadPoolExecutor, wait
from typing import Any, Dict, List, Optional, Set, Union

from pydantic import BaseModel, Field, model_validator

from qianfan import get_config
from qianfan.dataset import Dataset, QianfanDataSource
from qianfan.dataset.consts import (
Expand All @@ -43,6 +41,7 @@
QianfanRuleEvaluator,
)
from qianfan.model import Model, Service
from qianfan.pydantic import BaseModel, Field, root_validator
from qianfan.resources import Model as ResourceModel
from qianfan.resources.console.consts import EvaluationTaskStatus
from qianfan.utils import log_debug, log_error, log_info, log_warn
Expand All @@ -55,7 +54,7 @@ class EvaluationManager(BaseModel):
local_evaluators: Optional[List[LocalEvaluator]] = Field(default=None)
qianfan_evaluators: Optional[List[QianfanEvaluator]] = Field(default=None)

@model_validator(mode="before")
@root_validator
@classmethod
def _check_evaluators(cls, input_dict: Any) -> Any:
"""校验传入的参数"""
Expand Down
18 changes: 10 additions & 8 deletions src/qianfan/evaluation/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,12 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, model_validator

from qianfan.evaluation.consts import (
QianfanRefereeEvaluatorDefaultMaxScore,
QianfanRefereeEvaluatorDefaultMetrics,
QianfanRefereeEvaluatorDefaultSteps,
)
from qianfan.pydantic import BaseModel, Field, root_validator
from qianfan.utils import log_error, log_warn


Expand Down Expand Up @@ -85,7 +84,7 @@ class QianfanManualEvaluator(QianfanEvaluator):
default=[ManualEvaluatorDimension(dimension="满意度")]
)

@model_validator(mode="before")
@root_validator
@classmethod
def dimension_validation(cls, input_dict: Any) -> Any:
assert isinstance(input_dict, dict)
Expand Down Expand Up @@ -120,16 +119,19 @@ class Config:

open_compass_evaluator: BaseEvaluator

@model_validator(mode="after")
def _check_open_compass_evaluator(self) -> "OpenCompassLocalEvaluator":
signature = inspect.signature(self.open_compass_evaluator.score)
@root_validator
def _check_open_compass_evaluator(
cls, values: Dict[str, Any]
) -> Dict[str, Any]:
open_compass_evaluator = values["open_compass_evaluator"]
signature = inspect.signature(open_compass_evaluator.score)
params = list(signature.parameters.keys())
params.sort()
if params != ["predictions", "references"]:
raise ValueError(
f"unsupported opencompass evaluator {self.open_compass_evaluator}"
f"unsupported opencompass evaluator {type(open_compass_evaluator)}"
)
return self
return values

def evaluate(
self, input: Union[str, List[Dict[str, Any]]], reference: str, output: str
Expand Down
3 changes: 1 addition & 2 deletions src/qianfan/model/configs.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Any

from pydantic import BaseModel

from qianfan.model.consts import ServiceType
from qianfan.pydantic import BaseModel
from qianfan.resources.console import consts as console_consts


Expand Down
22 changes: 22 additions & 0 deletions src/qianfan/pydantic/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright (c) 2023 Baidu, Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Pydantic compatible layer
"""

try:
from pydantic.v1 import * # noqa
except ImportError:
from pydantic import * # noqa
11 changes: 11 additions & 0 deletions src/qianfan/tests/config_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""
Unit test for config
"""
import os

from qianfan import get_config
from qianfan.consts import DefaultValue

Expand All @@ -24,3 +26,12 @@ def test_rewrite_config_through_code():
assert get_config().AUTH_TIMEOUT == DefaultValue.AuthTimeout
config_center.AUTH_TIMEOUT = 114514
assert get_config().AUTH_TIMEOUT == 114514


def test_read_from_dot_env():
try:
with open(".env", "w") as f:
f.write('QIANFAN_ACCESS_TOKEN="test_token"')
assert get_config().ACCESS_TOKEN == "test_token"
finally:
os.remove(".env")
2 changes: 1 addition & 1 deletion src/qianfan/tests/dataset/dataset_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
from unittest.mock import patch

import pytest
from pydantic import BaseModel

from qianfan.dataset.consts import QianfanDataGroupColumnName
from qianfan.dataset.data_operator import FilterCheckNumberWords
Expand All @@ -30,6 +29,7 @@
QianfanNonSortedConversation,
QianfanSortedConversation,
)
from qianfan.pydantic import BaseModel
from qianfan.resources.console.consts import DataTemplateType


Expand Down
6 changes: 4 additions & 2 deletions src/qianfan/tests/tool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,8 @@ def test_tool_from_langchain_tool():
return

from langchain.tools.base import BaseTool as LangchainBaseTool
from pydantic.v1 import BaseModel, Field

from qianfan.pydantic import BaseModel, Field

class CalculatorToolSchema(BaseModel):
a: int = Field(description="a description")
Expand Down Expand Up @@ -145,7 +146,8 @@ def test_tool_from_langchain_func_tool():
return

from langchain.tools.base import Tool as LangchainTool
from pydantic.v1 import BaseModel, Field

from qianfan.pydantic import BaseModel, Field

def hello(a: str, b: str) -> str:
return f"hello {a} {b}"
Expand Down
3 changes: 1 addition & 2 deletions src/qianfan/trainer/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
# limitations under the License.
from typing import Any, Dict, List, Optional, Tuple, Union

from pydantic import BaseModel

from qianfan.pydantic import BaseModel
from qianfan.trainer.consts import PeftType


Expand Down

0 comments on commit 32a89b3

Please sign in to comment.