diff --git a/llmkira/error.py b/llmkira/error.py index b9719378d..faea86b35 100644 --- a/llmkira/error.py +++ b/llmkira/error.py @@ -4,11 +4,15 @@ # @File : error.py # @Software: PyCharm + class ReplyNeededError(Exception): """ Raised a error that need reply """ - pass + + def __init__(self, message: str = None, *args): + # 拦截 url 信息 + super().__init__(message, *args) # 更安全的 format diff --git a/llmkira/extra/plugins/_finish.py b/llmkira/extra/plugins/_finish.py index 4b3ca3228..bb4359455 100644 --- a/llmkira/extra/plugins/_finish.py +++ b/llmkira/extra/plugins/_finish.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from pydantic import ConfigDict + __package__name__ = "llmkira.extra.plugins.finish" __plugin_name__ = "finish_conversation" __openapi_version__ = "20231111" @@ -31,9 +33,7 @@ class Finish(BaseModel): comment: str = Field(default=":)", description="end with a question or a comment.(__language: $context)") - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class FinishTool(BaseTool): @@ -99,7 +99,7 @@ async def run(self, """ 处理message,返回message """ - _set = Finish.parse_obj(arg) + _set = Finish.model_validate(arg) # META _meta = task.task_meta.reply_message( plugin_name=__plugin_name__, diff --git a/llmkira/extra/plugins/_translate_doc.py b/llmkira/extra/plugins/_translate_doc.py index 526e90664..eae23a56e 100644 --- a/llmkira/extra/plugins/_translate_doc.py +++ b/llmkira/extra/plugins/_translate_doc.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from pydantic import ConfigDict + __package__name__ = "llmkira.extra.plugins.translate_file" __plugin_name__ = "translate_file" __openapi_version__ = "20231027" @@ -44,9 +46,7 @@ class Translate(BaseModel): language: str file_id: str - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class TranslateTool(BaseTool): @@ -179,7 +179,7 @@ async def run(self, for i in item.file: _translate_file.append(i) try: - translate_arg = Translate.parse_obj(arg) + translate_arg = Translate.model_validate(arg) except Exception: raise ValueError("Please specify the following parameters clearly\n file_id=xxx,language=xxx") _file_obj = [await i.raw_file() diff --git a/llmkira/extra/plugins/alarm/__init__.py b/llmkira/extra/plugins/alarm/__init__.py index 8d677c25f..d6f682c77 100644 --- a/llmkira/extra/plugins/alarm/__init__.py +++ b/llmkira/extra/plugins/alarm/__init__.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from pydantic import field_validator, ConfigDict + __package__name__ = "llmkira.extra.plugins.alarm" __plugin_name__ = "set_alarm_reminder" __openapi_version__ = "20231111" @@ -10,10 +12,9 @@ import datetime import re -import time from loguru import logger -from pydantic import validator, BaseModel +from pydantic import BaseModel from llmkira.receiver.aps import SCHEDULER from llmkira.schema import RawMessage @@ -21,7 +22,7 @@ from llmkira.sdk.func_calling.schema import FuncPair from llmkira.task import Task, TaskHeader -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from llmkira.sdk.schema import TaskBatch @@ -44,11 +45,9 @@ class Alarm(BaseModel): delay: int content: str + model_config = ConfigDict(extra="allow") - class Config: - extra = "allow" - - @validator("delay") + @field_validator("delay") def delay_validator(cls, v): if v < 0: raise ValueError("delay must be greater than 0") @@ -58,9 +57,9 @@ def delay_validator(cls, v): async def send_notify(_platform, _meta, _sender: dict, _receiver: dict, _user, _chat, _content: str): await Task(queue=_platform).send_task( task=TaskHeader( - sender=TaskHeader.Location.parse_obj(_sender), # 继承发送者 - receiver=TaskHeader.Location.parse_obj(_receiver), # 因为可能有转发,所以可以单配 - task_meta=TaskHeader.Meta.parse_obj(_meta), + sender=TaskHeader.Location.model_validate(_sender), # 继承发送者 + receiver=TaskHeader.Location.model_validate(_receiver), # 因为可能有转发,所以可以单配 + task_meta=TaskHeader.Meta.model_validate(_meta), message=[ RawMessage( user_id=_user, @@ -79,7 +78,7 @@ class AlarmTool(BaseTool): silent: bool = False function: Function = alarm keywords: list = ["闹钟", "提醒", "定时", "到点", '分钟'] - pattern = re.compile(r"(\d+)(分钟|小时|天|周|月|年)后提醒我(.*)") + pattern: Optional[re.Pattern] = re.compile(r"(\d+)(分钟|小时|天|周|月|年)后提醒我(.*)") require_auth: bool = True # env_required: list = ["SCHEDULER", "TIMEZONE"] @@ -148,7 +147,7 @@ async def run(self, """ 处理message,返回message """ - _set = Alarm.parse_obj(arg) + _set = Alarm.model_validate(arg) _meta = task.task_meta.reply_message( plugin_name=__plugin_name__, callback=[ @@ -163,22 +162,19 @@ async def run(self, logger.debug("Plugin:set alarm {} minutes later".format(_set.delay)) SCHEDULER.add_job( func=send_notify, - id=str(time.time()), + id=str(receiver.user_id), trigger="date", replace_existing=True, + misfire_grace_time=1000, run_date=datetime.datetime.now() + datetime.timedelta(minutes=_set.delay), args=[ task.receiver.platform, - _meta.dict(), - task.sender.dict(), receiver.dict(), + _meta.model_dump(), + task.sender.model_dump(), receiver.model_dump(), receiver.user_id, receiver.chat_id, _set.content ] ) - try: - SCHEDULER.start() - except Exception as e: - print(f"[155035]{e}") await Task(queue=receiver.platform).send_task( task=TaskHeader( sender=task.sender, # 继承发送者 diff --git a/llmkira/extra/plugins/search.py b/llmkira/extra/plugins/search.py index e45c402c3..ab9b9467a 100644 --- a/llmkira/extra/plugins/search.py +++ b/llmkira/extra/plugins/search.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from pydantic import ConfigDict + __package__name__ = "llmkira.extra.plugins.search" __plugin_name__ = "search_in_google" __openapi_version__ = "20231111" @@ -74,9 +76,7 @@ def search_on_duckduckgo(search_sentence: str, key_words: str = None): class Search(BaseModel): keywords: str - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class SearchTool(BaseTool): @@ -200,7 +200,7 @@ async def run(self, 处理message,返回message """ - _set = Search.parse_obj(arg) + _set = Search.model_validate(arg) _search_result = search_on_duckduckgo(_set.keywords) # META _meta = task.task_meta.reply_raw( diff --git a/llmkira/extra/plugins/sticker.py b/llmkira/extra/plugins/sticker.py index 4745846ba..30cd30b6d 100644 --- a/llmkira/extra/plugins/sticker.py +++ b/llmkira/extra/plugins/sticker.py @@ -1,4 +1,6 @@ # -*- coding: utf-8 -*- +from pydantic import field_validator, ConfigDict + __package__name__ = "llmkira.extra.plugins.sticker" __plugin_name__ = "convert_to_sticker" __openapi_version__ = "20231111" @@ -15,13 +17,13 @@ from PIL import Image from loguru import logger -from pydantic import validator, BaseModel +from pydantic import BaseModel from llmkira.schema import RawMessage from llmkira.sdk.func_calling import BaseTool from llmkira.sdk.func_calling.schema import FuncPair, PluginMetadata from llmkira.task import Task, TaskHeader -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional if TYPE_CHECKING: from llmkira.sdk.schema import TaskBatch @@ -43,11 +45,9 @@ class Sticker(BaseModel): yes_no: str comment: str = "done" + model_config = ConfigDict(extra="allow") - class Config: - extra = "allow" - - @validator("yes_no") + @field_validator("yes_no") def delay_validator(cls, v): if v != "yes": v = "no" @@ -86,7 +86,7 @@ class StickerTool(BaseTool): """ function: Function = sticker keywords: list = ["转换", "贴纸", ".jpg", "图像", '图片'] - file_match_required = re.compile(r"(.+).jpg|(.+).png|(.+).jpeg|(.+).gif|(.+).webp|(.+).svg") + file_match_required: Optional[re.Pattern] = re.compile(r"(.+).jpg|(.+).png|(.+).jpeg|(.+).gif|(.+).webp|(.+).svg") def pre_check(self): return True @@ -158,7 +158,7 @@ async def run(self, if item.file: for i in item.file: _file.append(i) - _set = Sticker.parse_obj(arg) + _set = Sticker.model_validate(arg) _file_obj = [await i.raw_file() for i in sorted(set(_file), key=_file.index)] # 去掉None _file_obj = [item for item in _file_obj if item] diff --git a/llmkira/extra/user/client.py b/llmkira/extra/user/client.py index 210861281..f859bce1d 100644 --- a/llmkira/extra/user/client.py +++ b/llmkira/extra/user/client.py @@ -27,7 +27,7 @@ def __init__(self): self.client = self.use_collection("user_cost") async def insert(self, data: UserCost): - await self.client.insert_one(data.dict()) + await self.client.insert_one(data.model_dump()) return data async def read_by_uid(self, uid: str) -> List[UserCost]: @@ -50,9 +50,9 @@ async def update(self, uid: str, data: UserConfig, validate: bool = True) -> Use [("uid", 1)], unique=True ) try: - await self.client.insert_one(data.dict()) + await self.client.insert_one(data.model_dump(mode="json")) except DuplicateKeyError: - await self.client.update_one({"uid": uid}, {"$set": data.dict()}) + await self.client.update_one({"uid": uid}, {"$set": data.model_dump(mode="json")}) return data async def read_by_uid(self, uid: str) -> Optional[UserConfig]: diff --git a/llmkira/extra/user/schema.py b/llmkira/extra/user/schema.py index ec45a78ae..4d28be8d0 100644 --- a/llmkira/extra/user/schema.py +++ b/llmkira/extra/user/schema.py @@ -8,7 +8,8 @@ from enum import Enum from typing import List, Union, Optional -from pydantic import BaseModel, Field, BaseSettings, validator +from pydantic import field_validator, ConfigDict, BaseModel, Field +from pydantic_settings import BaseSettings, SettingsConfigDict from ...sdk.endpoint import Driver @@ -35,17 +36,17 @@ class Cost(BaseModel): """ cost_by: str = Field("chat", description="环节") token_usage: int = Field(0) - token_uuid: str = Field(None, description="Api Key 的 hash") - model_name: str = Field(None, description="Model Name") + token_uuid: Optional[str] = Field(None, description="Api Key 的 hash") + llm_model: Optional[str] = Field(None, description="Model Name") provide_type: int = Field(None, description="认证模式") @classmethod def by_function(cls, function_name: str, token_usage: int, token_uuid: str, - model_name: str, + llm_model: str, ): - return cls(cost_by=function_name, token_usage=token_usage, token_uuid=token_uuid, model_name=model_name) + return cls(cost_by=function_name, token_usage=token_usage, token_uuid=token_uuid, model_name=llm_model) request_id: str = Field(default=None, description="请求 UUID") uid: str = Field(default=None, description="用户 UID ,注意是平台+用户") @@ -70,7 +71,7 @@ def create_from_function( function_name=cost_by, token_usage=token_usage, token_uuid=token_uuid, - model_name=model_name, + llm_model=model_name, ), cost_time=int(time.time()), ) @@ -89,13 +90,11 @@ def create_from_task( cost_time=int(time.time()), ) - class Config: - extra = "ignore" - allow_mutation = True - arbitrary_types_allowed = True - validate_assignment = True - validate_all = True - validate_on_assignment = True + model_config = ConfigDict(extra="ignore", + arbitrary_types_allowed=True, + validate_assignment=True, + validate_default=True + ) class UserConfig(BaseSettings): @@ -133,7 +132,7 @@ def set_proxy_public(self, token: str, provider: str): self.token = token return self - @validator("provider") + @field_validator("provider") def upper_provider(cls, v): if v: return v.upper() @@ -146,12 +145,12 @@ class PluginConfig(BaseModel): def default(cls): return cls() - def block(self, plugin_name: str) -> "PluginConfig": + def block(self, plugin_name: str) -> "UserConfig.PluginConfig": if plugin_name not in self.block_list: self.block_list.append(plugin_name) return self - def unblock(self, plugin_name: str) -> "PluginConfig": + def unblock(self, plugin_name: str) -> "UserConfig.PluginConfig": if plugin_name in self.block_list: self.block_list.remove(plugin_name) return self @@ -162,16 +161,15 @@ def unblock(self, plugin_name: str) -> "PluginConfig": plugin_subs: PluginConfig = Field(default_factory=PluginConfig.default, description="插件订阅") llm_driver: LlmConfig = Field(default_factory=LlmConfig.default, description="驱动") - @validator("uid") + @field_validator("uid") def check_user_id(cls, v): if v: return str(v) return v - class Config: - extra = "ignore" - allow_mutation = True - arbitrary_types_allowed = True - validate_assignment = True - validate_all = True - validate_on_assignment = True + model_config = SettingsConfigDict(extra="ignore", + frozen=True, + arbitrary_types_allowed=True, + validate_assignment=True, + validate_default=True + ) diff --git a/llmkira/middleware/chain_box/__init__.py b/llmkira/middleware/chain_box/__init__.py index 34e8226e9..b834e3960 100644 --- a/llmkira/middleware/chain_box/__init__.py +++ b/llmkira/middleware/chain_box/__init__.py @@ -7,8 +7,8 @@ from loguru import logger -from ...sdk.cache.redis import cache from .schema import Chain +from ...sdk.cache.redis import cache from ...task import TaskHeader @@ -27,7 +27,7 @@ def from_meta(cls, platform: str, user_id: str): return _c async def add_auth(self, chain: Chain): - _cache = await cache.set_data(key=f"auth:{chain.uuid}", value=chain.json(), timeout=60 * 60 * 24 * 7) + _cache = await cache.set_data(key=f"auth:{chain.uuid}", value=chain.model_dump_json(), timeout=60 * 60 * 24 * 7) return chain.uuid async def get_auth(self, uuid: str) -> Optional[Chain]: @@ -53,7 +53,7 @@ def __init__(self, uid: str): self.uid = uid async def add_task(self, chain: Chain): - _cache = await cache.lpush_data(key=f"chain:{self.uid}", value=chain.json()) + _cache = await cache.lpush_data(key=f"chain:{self.uid}", value=chain.model_dump_json()) return chain.uuid async def get_task(self) -> Optional[Chain]: diff --git a/llmkira/middleware/chain_box/schema.py b/llmkira/middleware/chain_box/schema.py index b683d3e40..7a490c049 100644 --- a/llmkira/middleware/chain_box/schema.py +++ b/llmkira/middleware/chain_box/schema.py @@ -4,10 +4,10 @@ # @File : schema.py # @Software: PyCharm import time -from typing import Any, Type +from typing import Any, Type, Optional import shortuuid -from pydantic import BaseModel, Field, validator +from pydantic import field_validator, BaseModel, Field class Chain(BaseModel): @@ -34,13 +34,13 @@ def from_redis(cls, data: dict): data["deprecated"] = True if not data.get("expire"): data["deprecated"] = True - return cls.parse_obj(data) + return cls.model_validate(data) @property def is_expire(self): return (int(time.time()) - self.created_times > self.expire) or self.deprecated - @validator("uid") + @field_validator("uid") def check_user_id(cls, v): if v.count(":") != 1: raise ValueError("Chain:uid format error") @@ -48,7 +48,7 @@ def check_user_id(cls, v): raise ValueError("Chain:uid is empty") return v - @validator("address") + @field_validator("address") def check_address(cls, v): if not v: raise ValueError("Chain:address is empty") @@ -59,5 +59,5 @@ def format_arg(self, arg: Type[BaseModel]): 神祗的格式化 """ if isinstance(self.arg, dict): - self.arg = arg.parse_obj(self.arg) + self.arg = arg.model_validate(self.arg) return self diff --git a/llmkira/middleware/llm_task.py b/llmkira/middleware/llm_task.py index 83712bd03..5f944c2d3 100644 --- a/llmkira/middleware/llm_task.py +++ b/llmkira/middleware/llm_task.py @@ -108,6 +108,9 @@ def __init__(self, ) def init(self): + """ + :raise: ProviderException + """ # 构建请求的驱动信息 self.auth_client = GetAuthDriver(uid=self.session_user_uid) self.driver = sync(self.auth_client.get()) @@ -132,13 +135,14 @@ def _unique_function(self): def write_back( self, *, - message: Message + message: Optional[Message] = None ): """ 写回消息到 Redis 数据库中 function 写回必须指定 name """ - self.message_history.add_message(message=message) + if message: + self.message_history.add_message(message=message) def _append_function_tools(self, functions: List[Function]): """ @@ -214,7 +218,7 @@ async def request_openai( :param retrieve_mode: 是否为检索模式,当我们需要重新处理超长消息时候,需要设定为 True :return: OpenaiResult """ - run_driver_model = self.driver.model if not retrieve_mode else self.driver.model_retrieve + run_driver_model = self.driver.model if not retrieve_mode else self.driver.retrieve_model endpoint_schema = self.get_schema(model_name=run_driver_model) # 添加函数定义的系统提示 if not disable_function: @@ -281,6 +285,7 @@ async def request_openai( _message = result.default_message _usage = result.usage.total_tokens self.message_history.add_message(message=_message) + # print(result.model_dump_json(indent=2)) # 记录消耗 await CostControl.add_cost( cost=UserCost.create_from_task( @@ -290,7 +295,7 @@ async def request_openai( cost_by=self.task.receiver.platform, token_usage=_usage, token_uuid=self.driver.uuid, - model_name=self.driver.model, + llm_model=self.driver.model, provide_type=self.auth_client.provide_type().value ) ) diff --git a/llmkira/middleware/router/__init__.py b/llmkira/middleware/router/__init__.py index 4f623e07a..1ab7ed282 100644 --- a/llmkira/middleware/router/__init__.py +++ b/llmkira/middleware/router/__init__.py @@ -19,15 +19,15 @@ def __init__(self): async def _upload(self): assert isinstance(self.router, RouterCache), "router info error" - self.router = RouterCache.parse_obj(self.router.dict()) - return await cache.set_data(key=self.__redis_key__, value=self.router.json()) + # self.router = RouterCache.model_validate(self.router.dict()) + return await cache.set_data(key=self.__redis_key__, value=self.router.model_dump_json()) async def _sync(self) -> RouterCache: _cache = await cache.read_data(key=self.__redis_key__) if not _cache: return RouterCache() try: - sub_info = RouterCache().parse_obj(_cache) + sub_info = RouterCache().model_validate(_cache) except Exception: raise Exception(f"not found router info") return sub_info diff --git a/llmkira/middleware/router/schema.py b/llmkira/middleware/router/schema.py index 87cbef9c3..f6190246e 100644 --- a/llmkira/middleware/router/schema.py +++ b/llmkira/middleware/router/schema.py @@ -5,7 +5,7 @@ # @Software: PyCharm from typing import Literal, List -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator SENDER = {} RECEIVER = {} @@ -52,7 +52,7 @@ def dsn(self, user_dsn=False): class RouterCache(BaseModel): router: List[Router] = [] - @validator("router", always=True) + @field_validator("router") def router_validate(cls, v): _dict = {} for item in v: diff --git a/llmkira/middleware/service_provider/public.py b/llmkira/middleware/service_provider/public.py index 079c533fa..b6061c6e7 100644 --- a/llmkira/middleware/service_provider/public.py +++ b/llmkira/middleware/service_provider/public.py @@ -58,20 +58,20 @@ async def check_times(self, times: int, uid: str): return True logger.debug(f"🍦 Public Provider Check Times UID({uid}) Read({read})") if read: - _data: UserToday = UserToday.parse_obj(read) + _data: UserToday = UserToday.model_validate(read) if str(_data.time) != str(date): - await cache.set_data(self.__database_key(uid=uid), value=UserToday().dict()) + await cache.set_data(self.__database_key(uid=uid), value=UserToday().model_dump()) return True else: if _data.count > times: return False if _data.count < times: _data.count += 1 - await cache.set_data(self.__database_key(uid=uid), value=_data.dict()) + await cache.set_data(self.__database_key(uid=uid), value=_data.model_dump()) return True else: _data = UserToday() - await cache.set_data(self.__database_key(uid=uid), value=_data.dict()) + await cache.set_data(self.__database_key(uid=uid), value=_data.model_dump()) return True return False diff --git a/llmkira/middleware/service_provider/schema.py b/llmkira/middleware/service_provider/schema.py index 1c81b1070..c3e795422 100644 --- a/llmkira/middleware/service_provider/schema.py +++ b/llmkira/middleware/service_provider/schema.py @@ -5,19 +5,20 @@ # @Software: PyCharm from abc import ABC, abstractmethod -from pydantic import BaseSettings, Field, validator +from pydantic import field_validator, Field from llmkira.sdk.endpoint import Driver +from pydantic_settings import BaseSettings class ProviderSetting(BaseSettings): - provider: str = Field("PUBLIC", env="SERVICE_PROVIDER") + provider: str = Field("PUBLIC", validation_alias="SERVICE_PROVIDER") @property def is_open_everyone(self): return self.provider.upper() == "PUBLIC" - @validator("provider") + @field_validator("provider") def provider_upper(cls, v): return v.upper() diff --git a/llmkira/receiver/discord/__init__.py b/llmkira/receiver/discord/__init__.py index ce00638fc..a859a9f43 100644 --- a/llmkira/receiver/discord/__init__.py +++ b/llmkira/receiver/discord/__init__.py @@ -186,7 +186,7 @@ async def function(self, llm_result=llm_result, receiver=receiver ) - new_receiver = task.receiver.copy() + new_receiver = task.receiver.model_copy() new_receiver.platform = __receiver__ """更新接收者为当前平台,便于创建的函数消息能返回到正确的客户端""" new_meta = task.task_meta.pack_loop( diff --git a/llmkira/receiver/function.py b/llmkira/receiver/function.py index 063816821..385c5c17e 100644 --- a/llmkira/receiver/function.py +++ b/llmkira/receiver/function.py @@ -60,7 +60,7 @@ async def auth_chain(self, """ 认证链重发注册 """ - _task_forward: TaskHeader = task.copy() + _task_forward: TaskHeader = task.model_copy() meta: TaskHeader.Meta = _task_forward.task_meta.chain( name=__receiver__, write_back=False, # 因为是发送给自己,所以不需要写回 @@ -105,7 +105,7 @@ async def resign_chain( :param repeatable: 是否可重复 :param deploy_child: 是否部署子链 """ - _task_forward: TaskHeader = task.copy() + _task_forward: TaskHeader = task.model_copy() # 添加认证链并重置路由数据 meta: TaskHeader.Meta = _task_forward.task_meta.chain( name=__receiver__, @@ -189,7 +189,7 @@ async def run_pending_task(task: TaskHeader, pending_task: TaskBatch): task=task, text=f"🔭 Sorry function `{pending_task.get_batch_name()}` executor not found" ) - return ModuleNotFoundError(f"Function {pending_task.get_batch_name()} not found") + raise ModuleNotFoundError(f"Function {pending_task.get_batch_name()} not found") # Run Function _tool_obj = _tool_cls() if _tool_obj.require_auth: @@ -232,7 +232,7 @@ async def process_function_call(self, message: AbstractIncomingMessage # Parse Message if os.getenv("LLMBOT_STOP_REPLY") == "1": return None - task: TaskHeader = TaskHeader.parse_raw(message.body.decode("utf-8")) + task: TaskHeader = TaskHeader.model_validate_json(json_data=message.body.decode("utf-8")) # Get Function Call pending_task = await task.task_meta.work_pending_task( verify_uuid=task.task_meta.verify_uuid diff --git a/llmkira/receiver/kook/__init__.py b/llmkira/receiver/kook/__init__.py index 8a9dadd63..30584a447 100644 --- a/llmkira/receiver/kook/__init__.py +++ b/llmkira/receiver/kook/__init__.py @@ -170,7 +170,7 @@ async def function(self, llm_result=llm_result, receiver=receiver ) - new_receiver = task.receiver.copy() + new_receiver = task.receiver.model_copy() new_receiver.platform = __receiver__ """更新接收者为当前平台,便于创建的函数消息能返回到正确的客户端""" new_meta = task.task_meta.pack_loop( diff --git a/llmkira/receiver/receiver_client.py b/llmkira/receiver/receiver_client.py index 7a5ffdcf9..bbdd479c0 100644 --- a/llmkira/receiver/receiver_client.py +++ b/llmkira/receiver/receiver_client.py @@ -12,9 +12,11 @@ from abc import ABCMeta, abstractmethod from typing import Optional, Tuple, List +import httpx import shortuuid from aio_pika.abc import AbstractIncomingMessage from loguru import logger +from pydantic import ValidationError as PydanticValidationError from telebot import formatting from llmkira.error import get_request_error_message, ReplyNeededError @@ -89,19 +91,13 @@ async def push_task_create_message(self, ): auth_map = {} - async def _action_block(_task_batch: TaskBatch): + async def _action_block(_task_batch: TaskBatch) -> Tuple[List[str], bool]: _tool = ToolRegister().get_tool(_task_batch.get_batch_name()) if not _tool: logger.warning(f"not found function {_task_batch.get_batch_name()}") - return await self.forward( - receiver=receiver, - message=[ - RawMessage( - text=f"🔭 Sorry function `{_task_batch.get_batch_name()}` not found", - only_send_file=False - ) - ] - ) + return [ + formatting.mbold("🍩 [Unknown]") + f" `{_task_batch.get_batch_name()}` " + ], False tool = _tool() icon = "🌟" if tool.require_auth: @@ -136,6 +132,7 @@ async def _action_block(_task_batch: TaskBatch): formatting.mbold("💫 Plan") + f" `{llm_result.id[-4:]}` ", ] total_silent = True + assert isinstance(task_batch, list), f"task batch type error {type(task_batch)}" for _task_batch in task_batch: _message, _silent = await _action_block(_task_batch=_task_batch) if not _silent: @@ -203,16 +200,22 @@ async def llm_request( return _result except ssl.SSLSyscallError as e: logger.error(f"[Network ssl error] {e},that maybe caused by bad proxy") - raise ReplyNeededError(e) + raise Exception(e) + except httpx.RemoteProtocolError as e: + logger.error(f"[Network RemoteProtocolError] {e}") + raise ReplyNeededError(message=f"Server disconnected without sending a response.") except ServiceUnavailableError as e: logger.error(f"[Service Unavailable Error] {e}") - raise ReplyNeededError(e) + raise ReplyNeededError(message=f"[551721]Service Unavailable {e}") except RateLimitError as e: logger.error(f"ApiEndPoint:{e}") - raise ReplyNeededError(e) + raise ReplyNeededError(message=f"[551580]Rate Limit Error {e}") except ProviderException as e: logger.info(f"[Service Provider]{e}") - raise ReplyNeededError(e) + raise ReplyNeededError(message=f"[551183]Service Provider Error {e}") + except PydanticValidationError as e: + logger.exception(e) + raise ReplyNeededError(message=f"[551684]Request Data ValidationError") except Exception as e: logger.exception(e) raise e @@ -249,13 +252,12 @@ async def _flash(self, if not isinstance(get_message, AssistantMessage): raise ReplyNeededError("[55682]Request Result Not Valid, Must Be `AssistantMessage`") except Exception as e: - await self.sender.error( - receiver=task.receiver, - text=get_request_error_message(str(e)) - ) - if not isinstance(e, ReplyNeededError): - raise e - return None + if isinstance(e, ReplyNeededError): + await self.sender.error( + receiver=task.receiver, + text=get_request_error_message(str(e)) + ) + raise e if intercept_function: if get_message.sign_function: await self.sender.reply( @@ -284,7 +286,7 @@ async def deal_message(self, message) -> Tuple[ 处理消息 """ logger.debug(f"[x] Received Message \n--message {message.body}") - _task: TaskHeader = TaskHeader.parse_raw(message.body.decode("utf-8")) + _task: TaskHeader = TaskHeader.model_validate_json(message.body.decode("utf-8")) # 没有任何参数 if _task.task_meta.direct_reply: await self.sender.forward( @@ -295,16 +297,22 @@ async def deal_message(self, message) -> Tuple[ functions = await FunctionReorganize(task=_task).build_arg() """函数组建,自动过滤拉黑后的插件和错误过多的插件""" - - _llm = OpenaiMiddleware( - task=_task, - functions=functions, - tools=[] - # 内部会初始化函数工具,这里是其他类型工具 - ).init() + try: + _llm = OpenaiMiddleware( + task=_task, + functions=functions, + tools=[] + # 内部会初始化函数工具,这里是其他类型工具 + ).init() + except ProviderException as e: + await self.sender.error( + receiver=_task.receiver, + text=f"🥞 Auth System Report {formatting.escape_markdown(str(e))}" + ) + raise e """构建通信代理""" schema = _llm.get_schema() - logger.debug(f"[x] Received Order \n--order {_task.json()}") + logger.debug(f"[x] Received Order \n--order {_task.model_dump_json()}") # function_response write back if _task.task_meta.write_back: for call in _task.task_meta.callback: @@ -312,9 +320,12 @@ async def deal_message(self, message) -> Tuple[ _func_tool_msg = call.get_tool_message() elif schema.func_executor == "function_call": _func_tool_msg = call.get_function_message() + elif schema.func_executor == "unsupported": + _func_tool_msg = None else: raise NotImplementedError(f"func_executor {schema.func_executor} not implemented") """消息类型是由请求结果决定的。也就是理论不存在预料外的冲突。""" + _llm.write_back( message=_func_tool_msg ) @@ -353,7 +364,6 @@ async def on_message(self, message: AbstractIncomingMessage): try: if os.getenv("LLMBOT_STOP_REPLY") == "1": return None - # 处理消息 task, llm, point, release = await self.deal_message(message) # 启动链式函数应答循环 @@ -362,7 +372,6 @@ async def on_message(self, message: AbstractIncomingMessage): if chain: await Task(queue=chain.address).send_task(task=chain.arg) logger.info(f"🧀 Chain point release\n--callback_send_by {point}") - except Exception as e: logger.exception(e) await message.reject(requeue=False) diff --git a/llmkira/receiver/slack/__init__.py b/llmkira/receiver/slack/__init__.py index 7a2d208a4..b9da4ef21 100644 --- a/llmkira/receiver/slack/__init__.py +++ b/llmkira/receiver/slack/__init__.py @@ -148,7 +148,7 @@ async def function(self, llm_result=llm_result, receiver=receiver ) - new_receiver = task.receiver.copy() + new_receiver = task.receiver.model_copy() new_receiver.platform = __receiver__ """更新接收者为当前平台,便于创建的函数消息能返回到正确的客户端""" new_meta = task.task_meta.pack_loop( diff --git a/llmkira/receiver/telegram/__init__.py b/llmkira/receiver/telegram/__init__.py index 9106ac302..798e7ecce 100644 --- a/llmkira/receiver/telegram/__init__.py +++ b/llmkira/receiver/telegram/__init__.py @@ -16,7 +16,6 @@ from llmkira.schema import RawMessage from llmkira.sdk.endpoint.schema import LlmResult from llmkira.sdk.schema import Message, File -from llmkira.sdk.utils import sync from llmkira.setting.telegram import BotSetting from llmkira.task import Task, TaskHeader @@ -98,8 +97,9 @@ async def forward(self, receiver: TaskHeader.Location, message: List[RawMessage] reply_to_message_id=receiver.message_id, parse_mode="MarkdownV2" ) + # TODO Telegram format except telebot.apihelper.ApiTelegramException as e: - time.sleep(3) + time.sleep(1) logger.error(f"telegram send message error, retry\n{e}") self.bot.send_message( chat_id=receiver.chat_id, @@ -150,7 +150,7 @@ async def function(self, llm_result=llm_result, receiver=receiver ) - new_receiver = task.receiver.copy() + new_receiver = task.receiver.model_copy() new_receiver.platform = __receiver__ """更新接收者为当前平台,便于创建的函数消息能返回到正确的客户端""" new_meta = task.task_meta.pack_loop( diff --git a/llmkira/schema.py b/llmkira/schema.py index 6b0e1f45d..c75672be0 100644 --- a/llmkira/schema.py +++ b/llmkira/schema.py @@ -4,11 +4,11 @@ # @File : schema.py # @Software: PyCharm import time -from typing import TYPE_CHECKING, Literal, Type +from typing import TYPE_CHECKING, Literal, Type, Optional from typing import Union, List import nest_asyncio -from pydantic import Field, BaseModel, validator +from pydantic import field_validator, ConfigDict, Field, BaseModel from .sdk.endpoint.tokenizer import get_tokenizer from .sdk.schema import File, generate_uid, UserMessage, Message @@ -22,9 +22,9 @@ class RawMessage(BaseModel): user_id: Union[int, str] = Field(None, description="user id") chat_id: Union[int, str] = Field(None, description="guild id(channel in dm)/Telegram chat id") - thread_id: Union[int, str] = Field(None, description="channel id/Telegram thread") + thread_id: Optional[Union[int, str]] = Field(None, description="channel id/Telegram thread") - text: str = Field(None, description="文本") + text: str = Field("", description="文本") file: List[File] = Field([], description="文件") created_at: str = Field(default=str(int(time.time())), description="创建时间") @@ -33,12 +33,9 @@ class RawMessage(BaseModel): sign_loop_end: bool = Field(default=False, description="要求其他链条不处理此消息,用于拦截器开发") sign_fold_docs: bool = Field(default=False, description="是否获取元数据") extra_kwargs: dict = Field(default={}, description="extra kwargs for loop") + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") - class Config: - arbitrary_types_allowed = True - extra = "allow" - - @validator("text") + @field_validator("text") def check_text(cls, v): if v == "": v = "" @@ -79,7 +76,7 @@ def format_openai_message( user_id=locate.user_id, text=message.content, chat_id=locate.chat_id, - created_at=int(time.time()) + created_at=str(int(time.time())) ) diff --git a/llmkira/sdk/adapter.py b/llmkira/sdk/adapter.py index 05455c646..7a8564c86 100644 --- a/llmkira/sdk/adapter.py +++ b/llmkira/sdk/adapter.py @@ -15,9 +15,9 @@ class SingleModel(BaseModel): - model_name: str + llm_model: str token_limit: int - func_executor: Literal["function_call", "tool_call"] + func_executor: Literal["function_call", "tool_call", "unsupported"] request: Type["LlmRequest"] response: Type["LlmResult"] schema_type: str @@ -33,7 +33,7 @@ class ModelMeta(object): def add_model(self, models: List[SingleModel]): for model in models: if not model.exception: - logger.debug(f"🥐 [Model Available] {model.model_name}") + logger.debug(f"🥐 [Model Available] {model.llm_model}") self.model_list.append(model) def get_by_model_name(self, @@ -41,7 +41,7 @@ def get_by_model_name(self, model_name: str ) -> SingleModel: for model in self.model_list: - if model.model_name == model_name: + if model.llm_model == model_name: if model.exception: raise NotImplementedError( f"model {model_name} not implemented" @@ -49,12 +49,12 @@ def get_by_model_name(self, ) return model raise LookupError( - f"model {model_name} not found" + f"model {model_name} not found! " f"please check your model name" ) def get_model_list(self): - return [model.model_name for model in self.model_list] + return [model.llm_model for model in self.model_list] def get_token_limit(self, *, diff --git a/llmkira/sdk/endpoint/openai/__init__.py b/llmkira/sdk/endpoint/openai/__init__.py index ae927cdba..8703d7e6e 100644 --- a/llmkira/sdk/endpoint/openai/__init__.py +++ b/llmkira/sdk/endpoint/openai/__init__.py @@ -3,6 +3,8 @@ # @Author : sudoskys # @File : base.py # @Software: PyCharm +from pydantic import field_validator, ConfigDict, model_validator + __version__ = "0.0.1" from typing import Union, List, Optional, Literal, Type @@ -11,7 +13,7 @@ import pydantic from dotenv import load_dotenv from loguru import logger -from pydantic import BaseModel, root_validator, validator, Field, PrivateAttr +from pydantic import BaseModel, Field, PrivateAttr from ..schema import LlmResult, LlmRequest from ..tee import Driver @@ -50,17 +52,14 @@ def sign_function(self): "tool_calls" == self.finish_reason ) - id: Optional[str] = Field(default=None, alias="request_id") + id: str object: str created: int model: str system_fingerprint: str = Field(default=None, alias="system_prompt_fingerprint") choices: List[Choices] usage: Usage - - class Config: - arbitrary_types_allowed = True - extra = "allow" + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") @property def result_type(self): @@ -72,37 +71,35 @@ def default_message(self): class Openai(LlmRequest): - class Config: - arbitrary_types_allowed = True - extra = "allow" + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") _config: Driver = PrivateAttr(default=None) """模型信息和配置""" messages: List[Union[Message, Type[Message]]] temperature: Optional[float] = 1 n: Optional[int] = 1 - top_p: Optional[float] - stop: Optional[Union[str, List[str]]] - max_tokens: Optional[int] - presence_penalty: Optional[float] - frequency_penalty: Optional[float] - seed: Optional[int] + top_p: Optional[float] = None + stop: Optional[Union[str, List[str]]] = None + max_tokens: Optional[int] = None + presence_penalty: Optional[float] = None + frequency_penalty: Optional[float] = None + seed: Optional[int] = None """基础设置""" stream: Optional[bool] = False """暂时不打算用的流式""" - logit_bias: Optional[dict] + logit_bias: Optional[dict] = None """暂时不打算用的logit_bias""" - user: Optional[str] + user: Optional[str] = None """追踪 User""" - response_format: Optional[dict] + response_format: Optional[dict] = None """回复指定的格式,See: https://platform.openai.com/docs/api-reference/chat/create""" # 函数 - functions: Optional[List[Function]] + functions: Optional[List[Function]] = None """deprecated""" function_call: Optional[Union[BaseFunction, Literal["auto", "none"]]] = None """deprecated""" # 工具 - tools: Optional[List[Tool]] + tools: Optional[List[Tool]] = None tool_choice: Optional[Union[ToolChoice, Literal["auto", "none"]]] = None """工具调用""" @@ -112,7 +109,8 @@ class Config: 注意验证 tool choice 的字段。 """ - @root_validator + @model_validator(mode="before") + @classmethod def fix_tool(cls, values): if not values.get("tools"): values["tools"] = None @@ -182,7 +180,7 @@ def create_params(self): ) self.messages = _new_messages # - _arg = self.dict( + _arg = self.model_dump( exclude_none=True, include=self.schema_map ) @@ -195,19 +193,19 @@ def create_params(self): } return _arg - @validator("presence_penalty") + @field_validator("presence_penalty") def check_presence_penalty(cls, v): if not (2 > v > -2): raise ValidationError("presence_penalty must be between -2 and 2") return v - @validator("stop") + @field_validator("stop") def check_stop(cls, v): if isinstance(v, list) and len(v) > 4: raise ValidationError("stop list length must be less than 4") return v - @validator("temperature") + @field_validator("temperature") def check_temperature(cls, v): if not (2 > v > 0): raise ValidationError("temperature must be between 0 and 2") @@ -250,14 +248,14 @@ async def create(self, ) assert _response, ValidationError("response is empty") logger.debug(f"[Openai response] {_response}") - return_result = OpenaiResult.parse_obj(_response).ack() + return_result = OpenaiResult.model_validate(_response).ack() except httpx.ConnectError as e: logger.error(f"[Openai connect error] {e}") raise e except pydantic.ValidationError as e: logger.error(f"[Api format error] {e}") raise e - if self._echo: + if self.echo: logger.info(f"[Openai Raw response] {return_result}") return return_result @@ -265,7 +263,7 @@ async def create(self, SCHEMA_GROUP.add_model( models=[ SingleModel( - model_name="chatglm3", + llm_model="chatglm3", token_limit=4096, request=Openai, response=OpenaiResult, @@ -274,7 +272,7 @@ async def create(self, exception=None ), SingleModel( - model_name="chatglm3-16k", + llm_model="chatglm3-16k", token_limit=16384, request=Openai, response=OpenaiResult, @@ -288,7 +286,7 @@ async def create(self, SCHEMA_GROUP.add_model( models=[ SingleModel( - model_name="gpt-3.5-turbo-1106", + llm_model="gpt-3.5-turbo-1106", token_limit=16384, request=Openai, response=OpenaiResult, @@ -297,7 +295,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-3.5-turbo", + llm_model="gpt-3.5-turbo", token_limit=4096, request=Openai, response=OpenaiResult, @@ -306,7 +304,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-3.5-turbo-16k", + llm_model="gpt-3.5-turbo-16k", token_limit=16384, request=Openai, response=OpenaiResult, @@ -315,7 +313,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-3.5-turbo-0613", + llm_model="gpt-3.5-turbo-0613", token_limit=4096, request=Openai, response=OpenaiResult, @@ -324,7 +322,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-3.5-turbo-16k-0613", + llm_model="gpt-3.5-turbo-16k-0613", token_limit=16384, request=Openai, response=OpenaiResult, @@ -333,7 +331,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-4", + llm_model="gpt-4", token_limit=8192, request=Openai, response=OpenaiResult, @@ -342,7 +340,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-4-32k", + llm_model="gpt-4-32k", token_limit=32768, request=Openai, response=OpenaiResult, @@ -351,7 +349,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-4-0613", + llm_model="gpt-4-0613", token_limit=8192, request=Openai, response=OpenaiResult, @@ -360,7 +358,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-4-vision-preview", + llm_model="gpt-4-vision-preview", token_limit=128000, request=Openai, response=OpenaiResult, @@ -369,7 +367,7 @@ async def create(self, exception=None ), SingleModel( - model_name="gpt-4-1106-preview", + llm_model="gpt-4-1106-preview", token_limit=128000, request=Openai, response=OpenaiResult, diff --git a/llmkira/sdk/endpoint/schema.py b/llmkira/sdk/endpoint/schema.py index 2e3134943..45a15da90 100644 --- a/llmkira/sdk/endpoint/schema.py +++ b/llmkira/sdk/endpoint/schema.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING import shortuuid -from pydantic import BaseModel, Field, PrivateAttr +from pydantic import ConfigDict, BaseModel, Field, PrivateAttr from .tee import Driver @@ -21,16 +21,13 @@ class LlmResult(BaseModel, ABC): LlmResult """ - id: str = Field(default_factory=lambda x: str(shortuuid.uuid()[0:8]), alias="request_id") + id: Optional[str] = Field(default=str(shortuuid.uuid())) object: str created: int model: str choices: list - usage: Any - - class Config: - arbitrary_types_allowed = True - extra = "allow" + usage: Any = None + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") @property def result_type(self): @@ -51,29 +48,30 @@ class LlmRequest(BaseModel, ABC): """ LlmRequest """ - - class Config: - arbitrary_types_allowed = True - extra = "allow" + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") _config: Driver = PrivateAttr() messages: list temperature: Optional[float] = 1 - top_p: Optional[float] + top_p: Optional[float] = None n: Optional[int] = 1 - stop: Optional[Union[str, List[str]]] - max_tokens: Optional[int] - presence_penalty: Optional[float] - frequency_penalty: Optional[float] - seed: Optional[int] + stop: Optional[Union[str, List[str]]] = None + max_tokens: Optional[int] = None + presence_penalty: Optional[float] = None + frequency_penalty: Optional[float] = None + seed: Optional[int] = None # 用于调试 - _echo: bool = Field(default=None) + __echo: bool = PrivateAttr(default=False) @property def config(self): return self._config + @property + def echo(self): + return self.__echo + @property def model(self): if self.config: @@ -112,7 +110,7 @@ def schema_map(self) -> dict: } def create_params(self): - _arg = self.dict( + _arg = self.model_dump( exclude_none=True, include=self.schema_map ) diff --git a/llmkira/sdk/endpoint/tee.py b/llmkira/sdk/endpoint/tee.py index 953294c7c..d9125c8d9 100644 --- a/llmkira/sdk/endpoint/tee.py +++ b/llmkira/sdk/endpoint/tee.py @@ -12,9 +12,10 @@ from typing import Optional from typing import TYPE_CHECKING -from pydantic import BaseSettings, Field, validator +from pydantic import field_validator, Field from ..error import ValidationError +from pydantic_settings import BaseSettings, SettingsConfigDict if TYPE_CHECKING: pass @@ -38,8 +39,8 @@ class Driver(BaseSettings): api_key: str = Field(default=None) org_id: Optional[str] = Field(default=None) model: str = Field(default="gpt-3.5-turbo-0613") - model_retrieve: str = Field(default="gpt-3.5-turbo-16k") - proxy_address: str = Field(None) + retrieve_model: str = Field(default="gpt-3.5-turbo-16k") + proxy_address: Optional[str] = Field(None) # TODO:AZURE API VERSION @@ -51,7 +52,7 @@ def detail(self): api_key = "****" + str(self.api_key)[-4:] return ( f"Endpoint: {self.endpoint}\nApi_key: {api_key}\n" - f"Org_id: {self.org_id}\nModel: {self.model}\nRetrieve_model: {self.model_retrieve}" + f"Org_id: {self.org_id}\nModel: {self.model}\nRetrieve_model: {self.retrieve_model}" ) @property @@ -79,7 +80,7 @@ def from_public_env(cls): proxy_address=openai_proxy ) - @validator("api_key") + @field_validator("api_key") def check_key(cls, v): if v: if len(str(v)) < 4: @@ -101,10 +102,4 @@ def uuid(self): """ _flag = self.api_key[-3:] return f"{_flag}:{sha1_encrypt(self.api_key)}" - - class Config: - env_file = '.env' - env_file_encoding = 'utf-8' - case_sensitive = True - arbitrary_types_allowed = True - extra = "allow" + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', case_sensitive=True, arbitrary_types_allowed=True, extra="allow") diff --git a/llmkira/sdk/endpoint/tokenizer.py b/llmkira/sdk/endpoint/tokenizer.py index 35132b469..454e5192c 100644 --- a/llmkira/sdk/endpoint/tokenizer.py +++ b/llmkira/sdk/endpoint/tokenizer.py @@ -13,7 +13,7 @@ def _pydantic_type(_message): if isinstance(_message, BaseModel): - return _message.dict() + return _message.model_dump() return _message @@ -30,7 +30,10 @@ def num_tokens_from_messages(self, messages: List[Union[dict, BaseModel, Type[Ba class OpenaiTokenizer(BaseTokenizer): - def num_tokens_from_messages(self, messages: List[Union[dict, BaseModel, Type[BaseModel]]], model: str) -> int: + def num_tokens_from_messages(self, + messages: List[Union[dict, BaseModel, Type[BaseModel]]], + model: str + ) -> int: """Return the number of tokens used by a list of messages_box.""" if hasattr(messages, "request_final"): messages: "Message" @@ -61,7 +64,9 @@ def num_tokens_from_messages(self, messages: List[Union[dict, BaseModel, Type[Ba print("Warning: gpt-4 may update over time. Returning num tokens assuming gpt-4-0613.") return self.num_tokens_from_messages(messages, model="gpt-4-0613") else: - raise NotImplementedError( + tokens_per_message = 4 # every message follows <|start|>{role/name}\n{content}<|end|>\n + tokens_per_name = 1 # if there's a name, the role is omitted + logger.warning( f"""num_tokens_from_messages() is not implemented for model {model}.""" """:) If you use a no-openai model, """ """you can [one-api](https://github.com/songquanpeng/one-api) project handle token usage.""" diff --git a/llmkira/sdk/func_calling/schema.py b/llmkira/sdk/func_calling/schema.py index 0d482cb4f..6f8e1e354 100644 --- a/llmkira/sdk/func_calling/schema.py +++ b/llmkira/sdk/func_calling/schema.py @@ -8,7 +8,7 @@ from typing import TYPE_CHECKING from loguru import logger -from pydantic import BaseModel, Field, validator, root_validator +from pydantic import field_validator, BaseModel, Field, model_validator if TYPE_CHECKING: from ...task import TaskHeader @@ -43,15 +43,14 @@ def name(self): return self.function.name @final - @root_validator - def _check_conflict(cls, values): - # env_required and silent - if values["silent"] and values["env_required"]: + @model_validator(mode="after") + def _check_conflict(self): + if self.silent and self.env_required: raise ValueError("silent and env_required can not be True at the same time") - return values + return self @final - @validator("keywords", pre=True) + @field_validator("keywords", mode="before") def _check_keywords(cls, v): for i in v: if not isinstance(i, str): diff --git a/llmkira/sdk/memory/redis/__init__.py b/llmkira/sdk/memory/redis/__init__.py index 043264f12..cabfec71d 100644 --- a/llmkira/sdk/memory/redis/__init__.py +++ b/llmkira/sdk/memory/redis/__init__.py @@ -6,7 +6,8 @@ import redis from loguru import logger -from pydantic import BaseSettings, Field, root_validator +from pydantic import Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict from llmkira.sdk.schema import Message, parse_message_dict from .utils import get_client @@ -14,16 +15,13 @@ class RedisChatMessageHistory(object): class RedisSettings(BaseSettings): - redis_url: str = Field("redis://localhost:6379/0", env="REDIS_DSN") + redis_url: str = Field("redis://localhost:6379/0", validation_alias="REDIS_DSN") redis_key_prefix: str = "llm_message_store_1:" + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra="ignore") - class Config: - env_file = '.env' - env_file_encoding = 'utf-8' - - @root_validator - def redis_is_connected(cls, values): - redis_url = values.get("redis_url") + @model_validator(mode="after") + def redis_is_connected(self): + redis_url = self.redis_url try: get_client(redis_url=redis_url) except redis.exceptions.ConnectionError as error: @@ -31,7 +29,7 @@ def redis_is_connected(cls, values): raise ValueError("Could not connect to Redis") else: logger.info("Core:Created Redis client successfully") - return values + return self def __init__( self, @@ -72,7 +70,7 @@ def messages(self) -> List[BaseMessage]: # type: ignore def add_message(self, message: Message) -> None: """Append the message to the record in Redis""" - self.redis_client.lpush(self.key, message.json()) + self.redis_client.lpush(self.key, message.model_dump_json()) if self.ttl: self.redis_client.expire(self.key, self.ttl) diff --git a/llmkira/sdk/schema.py b/llmkira/sdk/schema.py index dad8114d8..fd4a7c5f7 100644 --- a/llmkira/sdk/schema.py +++ b/llmkira/sdk/schema.py @@ -13,8 +13,9 @@ from typing import TYPE_CHECKING import shortuuid +from docstring_parser import parse from loguru import logger -from pydantic import BaseModel, root_validator, Field, PrivateAttr +from pydantic import model_validator, BaseModel, Field, PrivateAttr from .error import ValidationError, CheckError from .utils import sync @@ -45,8 +46,8 @@ def pair(self): return self.file_name, self.file_data file_id: Optional[str] = Field(None, description="文件ID") - file_name: str = Field(None, description="文件名") - file_url: str = Field(None, description="文件URL") + file_name: Optional[str] = Field(None, description="文件名") + file_url: Optional[str] = Field(None, description="文件URL") caption: str = Field(default='', description="文件注释") bytes: int = Field(default=None, description="文件大小") created_by: str = Field(default=None, description="上传者") @@ -76,7 +77,7 @@ def file_prompt(self): FOR LLM """ _comment = '(' - for key, value in self.dict().items(): + for key, value in self.model_dump().items(): if value: _comment += f"{key}={value}," return f"[Attachment{_comment[:-1]})]" @@ -148,14 +149,14 @@ class BaseFunction(BaseModel): """ class FunctionExtra(BaseModel): - system_prompt: str = Field(None, description="系统提示") + system_prompt: Optional[str] = Field(None, description="系统提示") @classmethod def default(cls): return cls(system_prompt=None) _config: FunctionExtra = FunctionExtra.default() - name: str = Field(None, description="函数名称", regex=r"^[a-zA-Z0-9_]+$") + name: Optional[str] = Field(None, description="函数名称", pattern=r"^[a-zA-Z0-9_]+$") def update_config(self, config: FunctionExtra) -> "BaseFunction": self._config = config @@ -172,8 +173,7 @@ def config(self) -> FunctionExtra: def request_final(self, *, schema_model: str): - return self.copy( - include={"name"} + return self.model_copy( ) @@ -188,7 +188,7 @@ class Parameters(BaseModel): properties: dict = {} required: List[str] = Field(default=[], description="必填参数") - name: str = Field(None, description="函数名称", regex=r"^[a-zA-Z0-9_]+$") + name: Optional[str] = Field(None, description="函数名称", pattern=r"^[a-zA-Z0-9_]+$") description: Optional[str] = None parameters: Parameters = Parameters(type="object") @@ -200,19 +200,17 @@ def request_final(self, :param schema_model: 适配的模型 """ if schema_model.startswith("gpt-"): - return self.copy( - include={"name", "description", "parameters"} + return self.model_copy( ) elif schema_model.startswith("chatglm"): - return self.copy( - include={"name", "description", "parameters"} + return self.model_copy( ) else: raise CheckError(f"unknown model {schema_model}, cant classify model type") def add_property(self, property_name: str, - property_type: Literal["string", "integer", "number", "boolean", "object"], + property_type: Literal["string", "integer", "number", "boolean", "object", "array"], property_description: str, enum: Optional[tuple] = None, required: bool = False @@ -228,12 +226,41 @@ def add_property(self, if required: self.parameters.required.append(property_name) - def parse_schema_to_properties(self, schema: Type[BaseModel]): + @classmethod + def parse_pydantic_schema(cls, schema_model: Type[BaseModel]): """ 解析 pydantic 的 schema """ - self.parameters.properties = schema.schema()["properties"] - self.parameters.required = schema.schema()["required"] + schema = schema_model.model_json_schema() + docstring = parse(schema.__doc__ or "") + parameters = { + k: v for k, v in schema.items() if k not in ("title", "description") + } + for param in docstring.params: + name = param.arg_name + description = param.description + if (name in parameters["properties"]) and description: + if "description" not in parameters["properties"][name]: + parameters["properties"][name]["description"] = description + + parameters["required"] = sorted( + k for k, v in parameters["properties"].items() if "default" not in v + ) + + if "description" not in schema: + if docstring.short_description: + schema["description"] = docstring.short_description + else: + schema["description"] = ( + f"Correctly extracted `{cls.__name__}` with all " + f"the required parameters with correct types" + ) + + return { + "name": schema["title"], + "description": schema["description"], + "parameters": parameters, + } class FunctionCallCompletion(BaseModel): @@ -254,17 +281,11 @@ def request_final( schema_model ): if schema_model.startswith("gpt-"): - return self.copy( - include={"type", "function"}, - ) + return self elif schema_model.startswith("chatglm"): - return self.copy( - include={"type", "function"}, - ) + return self else: - return self.copy( - include={"type", "function"}, - ) + return self # raise CheckError(f"unknown model {schema_model}, cant classify model type") @@ -327,7 +348,7 @@ class ContentParts(BaseModel): class Image(BaseModel): url: str - detail: Optional[str] + detail: Optional[str] = None type: str image_url: Optional[str] @@ -396,7 +417,7 @@ def request_final(self, class SystemMessage(Message): role: str = Field(default="system") content: str - name: Optional[str] = Field(default=None, description="speaker_name", regex=r"^[a-zA-Z0-9_]+$") + name: Optional[str] = Field(default=None, description="speaker_name", pattern=r"^[a-zA-Z0-9_]+$") def request_final(self, *, @@ -408,7 +429,7 @@ def request_final(self, class UserMessage(Message): role: str = Field(default="user") content: Union[str, List[ContentParts], List[dict]] - name: Optional[str] = Field(default=None, description="speaker_name", regex=r"^[a-zA-Z0-9_]+$") + name: Optional[str] = Field(default=None, description="speaker_name", pattern=r"^[a-zA-Z0-9_]+$") @property def fold(self) -> "Message": @@ -421,11 +442,10 @@ def fold(self) -> "Message": f"""\ntimestamp={self._meta.datatime}""" f"""\ndescription={self.content[:20] + "..."})""" ) - return self.copy( + return self.model_copy( update={ "content": metadata_str - }, - include={"role", "content", "name"} + } ) def request_final(self, @@ -471,21 +491,21 @@ def request_final(self, class AssistantMessage(Message): role: str = Field(default="assistant") content: Union[None, str] = Field(default='', description="assistant content") - name: Optional[str] = Field(default=None, description="speaker_name", regex=r"^[a-zA-Z0-9_]+$") + name: Optional[str] = Field(default=None, description="speaker_name", pattern=r"^[a-zA-Z0-9_]+$") tool_calls: Optional[List[ToolCallCompletion]] = Field(default=None, description="tool calls") """a array of tools, for result""" function_call: Optional[FunctionCallCompletion] = Field(default=None, description="Deprecated") """Deprecated by openai ,for result""" - @root_validator() - def deprecate_validator(cls, values): - if values.get("tool_calls") and values.get("function_call"): + @model_validator(mode="after") + def deprecate_validator(self): + if self.tool_calls and self.function_call: raise ValidationError("sdk param validator:tool_calls and function_call cannot both be provided") - if values.get("function_call"): + if self.function_call: logger.warning("sdk param validator:function_call is deprecated") - if values.get("content") is None: - values["content"] = "" - return values + if self.content is None: + self.content = "" + return self def get_executor_batch(self) -> List[TaskBatch]: """ @@ -536,12 +556,12 @@ class FunctionMessage(Message): content: str name: str - @root_validator() - def function_validator(cls, values): + @model_validator(mode="after") + def function_validator(self): logger.warning("Function Message is deprecated by openai") - if values.get("role") == "function" and not values.get("name"): + if self.role == "function" and not self.name: raise ValidationError("sdk param validator:name must be specified when role is function") - return values + return self def request_final(self, *, @@ -586,15 +606,15 @@ def parse_message_dict(item: dict): return None try: if role == "assistant": - _message = AssistantMessage.parse_obj(item) + _message = AssistantMessage.model_validate(item) elif role == "user": - _message = UserMessage.parse_obj(item) + _message = UserMessage.model_validate(item) elif role == "system": - _message = SystemMessage.parse_obj(item) + _message = SystemMessage.model_validate(item) elif role == "tool": - _message = ToolMessage.parse_obj(item) + _message = ToolMessage.model_validate(item) elif role == "function": - _message = FunctionMessage.parse_obj(item) + _message = FunctionMessage.model_validate(item) else: raise CheckError(f"unknown message type {role}") except Exception as e: @@ -613,7 +633,7 @@ def standardise_for_request( 标准化转换,供请求使用 """ if isinstance(message, dict): - message = Message.parse_obj(message) + message = Message.model_validate(message) if hasattr(message, "message"): return message.request_final(schema_model=schema_model) else: diff --git a/llmkira/sdk/utils.py b/llmkira/sdk/utils.py index ae3d95846..c663a3792 100644 --- a/llmkira/sdk/utils.py +++ b/llmkira/sdk/utils.py @@ -7,6 +7,7 @@ """ import asyncio import hashlib +from bisect import bisect_left from typing import Coroutine, Dict, List import aiohttp @@ -73,3 +74,16 @@ async def download_file(url, timeout=None, size_limit=None, headers=None): contents = await response.read() return contents + + +def prefix_search(wordlist, prefix): + """ + 在有序列表中二分查找前缀 + :param wordlist: 有序列表 + :param prefix: 前缀 + """ + try: + index = bisect_left(wordlist, prefix) + return wordlist[index].startswith(prefix) + except IndexError: + return False diff --git a/llmkira/sender/rss/rss.py b/llmkira/sender/rss/rss.py index 6bf5d7fbd..82b3aa9e9 100644 --- a/llmkira/sender/rss/rss.py +++ b/llmkira/sender/rss/rss.py @@ -20,8 +20,8 @@ from telebot import formatting from telebot.formatting import escape_markdown -from ...sdk.cache.redis import cache from ..schema import Runner +from ...sdk.cache.redis import cache from ...task import Task, TaskHeader __sender__ = "rss" @@ -138,7 +138,7 @@ def get_feed(self): async def re_init(self, update: Update) -> (str, list): _entry = list(update.entry.values())[:1] - await cache.set_data(key=self.db_key, value=update.json(), timeout=60 * 60 * 60 * 7) + await cache.set_data(key=self.db_key, value=update.model_dump_json(), timeout=60 * 60 * 60 * 7) return update.title, _entry async def update(self, cache_, update_, keys): @@ -147,7 +147,7 @@ async def update(self, cache_, update_, keys): # copy cache_.entry[key] = update_.entry[key] _return.append(update_.entry[key]) - await cache.set_data(key=self.db_key, value=cache_.json(), timeout=60 * 60 * 60 * 7) + await cache.set_data(key=self.db_key, value=cache_.model_dump_json(), timeout=60 * 60 * 60 * 7) return update_.title, _return async def get_updates(self): @@ -157,7 +157,7 @@ async def get_updates(self): if not _data: return await self.re_init(_load) assert isinstance(_data, dict), "wrong rss data" - _cache = self.Update.parse_obj(_data) + _cache = self.Update.model_validate(_data) # 验证是否全部不一样 _old = list(_cache.entry.keys()) diff --git a/llmkira/sender/slack/__init__.py b/llmkira/sender/slack/__init__.py index c9694515e..c3a6e0bca 100644 --- a/llmkira/sender/slack/__init__.py +++ b/llmkira/sender/slack/__init__.py @@ -166,7 +166,7 @@ async def create_task(event_: SlackMessageEvent, funtion_enable: bool = False): @bot.command(command='/clear_endpoint') async def listen_clear_endpoint_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() # _cmd, _arg = parse_command(command=message.text) _tips = "🪄 Done" @@ -177,7 +177,7 @@ async def listen_clear_endpoint_command(ack: AsyncAck, respond: AsyncRespond, co @bot.command(command='/set_endpoint') async def listen_set_endpoint_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() if not command.text: return @@ -202,7 +202,7 @@ async def listen_set_endpoint_command(ack: AsyncAck, respond: AsyncRespond, comm return await respond( text=formatting.format_text( formatting.mbold(f"🪄 Failed: {e}", escape=False), - formatting.mitalic("Format: /set_endpoint ##"), + formatting.mitalic("Format: /set_endpoint ##"), formatting.mitalic(f"Model Name: {UserControl.get_model()}"), separator="\n" ) @@ -218,7 +218,7 @@ async def listen_set_endpoint_command(ack: AsyncAck, respond: AsyncRespond, comm @bot.command(command='/func_ban') async def listen_func_ban_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() if not command.text: return @@ -247,7 +247,7 @@ async def listen_func_ban_command(ack: AsyncAck, respond: AsyncRespond, command) @bot.command(command='/func_unban') async def listen_func_unban_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() if not command.text: return @@ -276,7 +276,7 @@ async def listen_func_unban_command(ack: AsyncAck, respond: AsyncRespond, comman @bot.command(command='/token') async def listen_token_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() if not command.text: return @@ -304,7 +304,7 @@ async def listen_token_command(ack: AsyncAck, respond: AsyncRespond, command): @bot.command(command='/token_clear') async def listen_unbind_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() try: token = await UserControl.set_token( @@ -329,7 +329,7 @@ async def listen_unbind_command(ack: AsyncAck, respond: AsyncRespond, command): @bot.command(command='/bind') async def listen_bind_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() if not command.text: return @@ -357,7 +357,7 @@ async def listen_bind_command(ack: AsyncAck, respond: AsyncRespond, command): @bot.command(command='/unbind') async def listen_unbind_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() if not command.text: return @@ -385,7 +385,7 @@ async def listen_unbind_command(ack: AsyncAck, respond: AsyncRespond, command): @bot.command(command='/env') async def listen_env_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() if not command.text: return @@ -412,7 +412,7 @@ async def listen_env_command(ack: AsyncAck, respond: AsyncRespond, command): @bot.command(command='/clear') async def listen_clear_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() RedisChatMessageHistory(session_id=f"{__sender__}:{command.user_id}", ttl=60 * 60 * 1).clear() return await respond( @@ -424,7 +424,7 @@ async def listen_clear_command(ack: AsyncAck, respond: AsyncRespond, command): @bot.command(command='/help') async def listen_help_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() await respond( text=formatting.format_text( @@ -436,7 +436,7 @@ async def listen_help_command(ack: AsyncAck, respond: AsyncRespond, command): @bot.command(command='/tool') async def listen_tool_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() _tool = ToolRegister().functions _paper = [[c.name, c.description] for name, c in _tool.items()] @@ -467,7 +467,7 @@ async def auth_chain(uuid, user_id): @bot.command(command='/auth') async def listen_auth_command(ack: AsyncAck, respond: AsyncRespond, command): - command: SlashCommand = SlashCommand.parse_obj(command) + command: SlashCommand = SlashCommand.model_validate(command) await ack() if not command.text: return await respond( @@ -489,7 +489,7 @@ async def validate_join(event_: SlackMessageEvent): _res = await self.bot.client.conversations_info( channel=event_.channel ) - _channel: SlackChannelInfo = SlackChannelInfo.parse_obj(_res.get("channel", {})) + _channel: SlackChannelInfo = SlackChannelInfo.model_validate(_res.get("channel", {})) if not _channel.is_member: raise Exception("Not in channel") except Exception as e: @@ -553,16 +553,16 @@ async def deal_group(event_: SlackMessageEvent): return await create_task(event_, funtion_enable=__default_function_enable__) @bot.event("message") - async def listen_im(event, logger): + async def listen_im(_event, _logger): """ 自动响应私聊消息 """ - logger.info(event) - event_: SlackMessageEvent = SlackMessageEvent.parse_obj(event) + _logger.info(event) + event_: SlackMessageEvent = SlackMessageEvent.model_validate(_event) # 校验消息是否过期 if int(float(event_.event_ts)) < int(time.time()) - 60 * 60 * 5: - logger.warning(f"slack: message expired {event_.event_ts}") + _logger.warning(f"slack: message expired {event_.event_ts}") return if not event_.text: return None diff --git a/llmkira/sender/slack/event.py b/llmkira/sender/slack/event.py index 5ac7abf8f..4317f87c2 100644 --- a/llmkira/sender/slack/event.py +++ b/llmkira/sender/slack/event.py @@ -4,7 +4,7 @@ # @File : event.py # @Software: PyCharm -from pydantic import BaseModel +from pydantic import ConfigDict, BaseModel def help_message(): @@ -73,9 +73,7 @@ class SlashCommand(BaseModel): response_url: str = None trigger_id: str = None api_app_id: str = None - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") class SlackChannelInfo(BaseModel): @@ -107,6 +105,4 @@ class SlackChannelInfo(BaseModel): topic: dict = None purpose: dict = None previous_names: list = None - - class Config: - extra = "allow" + model_config = ConfigDict(extra="allow") diff --git a/llmkira/sender/slack/schema.py b/llmkira/sender/slack/schema.py index 8b4d70fc1..bd4023b11 100644 --- a/llmkira/sender/slack/schema.py +++ b/llmkira/sender/slack/schema.py @@ -5,7 +5,7 @@ # @Software: PyCharm from typing import List -from pydantic import BaseModel, Field +from pydantic import ConfigDict, BaseModel, Field class SlackMessageEvent(BaseModel): @@ -14,56 +14,53 @@ class SlackMessageEvent(BaseModel): """ class SlackFile(BaseModel): - id: str = Field(None, description="id") + id: Optional[str] = Field(None, description="id") created: int = Field(None, description="created") timestamp: int = Field(None, description="timestamp") - name: str = Field(None, description="name") - title: str = Field(None, description="title") - mimetype: str = Field(None, description="mimetype") - filetype: str = Field(None, description="filetype") - pretty_type: str = Field(None, description="pretty_type") - user: str = Field(None, description="user") - user_team: str = Field(None, description="user_team") + name: Optional[str] = Field(None, description="name") + title: Optional[str] = Field(None, description="title") + mimetype: Optional[str] = Field(None, description="mimetype") + filetype: Optional[str] = Field(None, description="filetype") + pretty_type: Optional[str] = Field(None, description="pretty_type") + user: Optional[str] = Field(None, description="user") + user_team: Optional[str] = Field(None, description="user_team") editable: bool = Field(None, description="editable") size: int = Field(None, description="size") - mode: str = Field(None, description="mode") + mode: Optional[str] = Field(None, description="mode") is_external: bool = Field(None, description="is_external") - external_type: str = Field(None, description="external_type") + external_type: Optional[str] = Field(None, description="external_type") is_public: bool = Field(None, description="is_public") public_url_shared: bool = Field(None, description="public_url_shared") display_as_bot: bool = Field(None, description="display_as_bot") - username: str = Field(None, description="username") - url_private: str = Field(None, description="url_private") - url_private_download: str = Field(None, description="url_private_download") - media_display_type: str = Field(None, description="media_display_type") - thumb_64: str = Field(None, description="thumb_64") - thumb_80: str = Field(None, description="thumb_80") - thumb_360: str = Field(None, description="thumb_360") + username: Optional[str] = Field(None, description="username") + url_private: Optional[str] = Field(None, description="url_private") + url_private_download: Optional[str] = Field(None, description="url_private_download") + media_display_type: Optional[str] = Field(None, description="media_display_type") + thumb_64: Optional[str] = Field(None, description="thumb_64") + thumb_80: Optional[str] = Field(None, description="thumb_80") + thumb_360: Optional[str] = Field(None, description="thumb_360") thumb_360_w: int = Field(None, description="thumb_360_w") thumb_360_h: int = Field(None, description="thumb_360_h") - thumb_160: str = Field(None, description="thumb_160") + thumb_160: Optional[str] = Field(None, description="thumb_160") original_w: int = Field(None, description="original_w") original_h: int = Field(None, description="original_h") - thumb_tiny: str = Field(None, description="thumb_tiny") - permalink: str = Field(None, description="permalink") - permalink_public: str = Field(None, description="permalink_public") + thumb_tiny: Optional[str] = Field(None, description="thumb_tiny") + permalink: Optional[str] = Field(None, description="permalink") + permalink_public: Optional[str] = Field(None, description="permalink_public") has_rich_preview: bool = Field(None, description="has_rich_preview") - file_access: str = Field(None, description="file_access") + file_access: Optional[str] = Field(None, description="file_access") - client_msg_id: str = Field(None, description="client_msg_id") - type: str = Field(None, description="type") - text: str = Field(None, description="text") - user: str = Field(None, description="user") - ts: str = Field(None, description="ts") + client_msg_id: Optional[str] = Field(None, description="client_msg_id") + type: Optional[str] = Field(None, description="type") + text: Optional[str] = Field(None, description="text") + user: Optional[str] = Field(None, description="user") + ts: Optional[str] = Field(None, description="ts") blocks: List[dict] = Field([], description="blocks") - team: str = Field(None, description="team") - thread_ts: str = Field(None, description="thread_ts") - parent_user_id: str = Field(None, description="parent_user_id") - channel: str = Field(None, description="channel") - event_ts: str = Field(None, description="event_ts") - channel_type: str = Field(None, description="channel_type") + team: Optional[str] = Field(None, description="team") + thread_ts: Optional[str] = Field(None, description="thread_ts") + parent_user_id: Optional[str] = Field(None, description="parent_user_id") + channel: Optional[str] = Field(None, description="channel") + event_ts: Optional[str] = Field(None, description="event_ts") + channel_type: Optional[str] = Field(None, description="channel_type") files: List["SlackFile"] = Field(default=[], description="files") - - class Config: - arbitrary_types_allowed = True - extra = "allow" + model_config = ConfigDict(arbitrary_types_allowed=True, extra="allow") diff --git a/llmkira/setting/__init__.py b/llmkira/setting/__init__.py index 880c74ef1..6d1d521a0 100644 --- a/llmkira/setting/__init__.py +++ b/llmkira/setting/__init__.py @@ -3,7 +3,7 @@ # @Author : sudoskys # @File : __init__.py.py # @Software: PyCharm -from pydantic import Field, BaseModel +from pydantic import ConfigDict, Field, BaseModel from .discord import BotSetting as DiscordSetting from .kook import BotSetting as KookSetting @@ -19,9 +19,7 @@ class StartSetting(BaseModel): kook: bool = Field(default=False) slack: bool = Field(default=False) telegram: bool = Field(default=False) - - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) @classmethod def from_subdir(cls): diff --git a/llmkira/setting/discord.py b/llmkira/setting/discord.py index a955da26e..1a5051395 100644 --- a/llmkira/setting/discord.py +++ b/llmkira/setting/discord.py @@ -3,37 +3,31 @@ # @Author : sudoskys # @File : discord.py # @Software: PyCharm +from typing import Optional + from dotenv import load_dotenv from loguru import logger -from pydantic import BaseSettings, Field, validator, root_validator +from pydantic import Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict class DiscordBot(BaseSettings): """ 代理设置 """ - token: str = Field(None, env='DISCORD_BOT_TOKEN') - prefix: str = Field("/", env="DISCORD_BOT_PREFIX") - proxy_address: str = Field(None, env="DISCORD_BOT_PROXY_ADDRESS") # "all://127.0.0.1:7890" - bot_id: str = Field(None) - - class Config: - env_file = '.env' - env_file_encoding = 'utf-8' - - @validator('token') - def bot_token_validator(cls, v): - if v is None: - logger.warning(f"\n🍀Check:DiscordBot token is empty") - else: - logger.success(f"🍀Check:DiscordBot token ready") - return v - - @root_validator - def bot_setting_validator(cls, values): - if values['proxy_address']: - logger.success(f"DiscordBot proxy was set to {values['proxy_address']}") - return values + token: Optional[str] = Field(None, validation_alias='DISCORD_BOT_TOKEN', strict=True) + prefix: Optional[str] = Field("/", validation_alias="DISCORD_BOT_PREFIX") + proxy_address: Optional[str] = Field(None, validation_alias="DISCORD_BOT_PROXY_ADDRESS") # "all://127.0.0.1:7890" + bot_id: Optional[str] = Field(None) + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra="ignore") + + @model_validator(mode='after') + def bot_setting_validator(self): + if self.token is None: + logger.warning(f"\n🍀DiscordBot token is empty") + if self.proxy_address: + logger.success(f"DiscordBot proxy was set to {self.proxy_address}") + return self @property def available(self): diff --git a/llmkira/setting/kook.py b/llmkira/setting/kook.py index 9ce2e09aa..47038e520 100644 --- a/llmkira/setting/kook.py +++ b/llmkira/setting/kook.py @@ -3,34 +3,26 @@ # @Author : sudoskys # @File : kook.py # @Software: PyCharm +from typing import Optional + from dotenv import load_dotenv from loguru import logger -from pydantic import BaseSettings, Field, validator, root_validator +from pydantic import Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict class KookBot(BaseSettings): """ 代理设置 """ - token: str = Field(None, env='KOOK_BOT_TOKEN') - - # proxy_address: str = Field(None, env="DISCORD_BOT_PROXY_ADDRESS") # "all://127.0.0.1:7890" - - class Config: - env_file = '.env' - env_file_encoding = 'utf-8' - - @validator('token') - def token_validator(cls, v): - if v is None: - logger.warning(f"\n🍀Check:KookBot token is empty") - else: - logger.success(f"🍀Check:KookBot token ready") - return v - - @root_validator - def bot_setting_validator(cls, values): - return values + token: Optional[str] = Field(None, validation_alias='KOOK_BOT_TOKEN') + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra="ignore") + + @model_validator(mode='after') + def bot_setting_validator(self): + if self.token is None: + logger.warning(f"\n🍀KookBot token is empty") + return self @property def available(self): diff --git a/llmkira/setting/slack.py b/llmkira/setting/slack.py index ce62a180c..9d8c0a7e0 100644 --- a/llmkira/setting/slack.py +++ b/llmkira/setting/slack.py @@ -3,46 +3,46 @@ # @Author : sudoskys # @File : slack.py # @Software: PyCharm +from typing import Optional + from dotenv import load_dotenv from loguru import logger -from pydantic import BaseSettings, Field, root_validator +from pydantic import Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict class SlackBot(BaseSettings): """ 代理设置 """ - app_token: str = Field(None, env='SLACK_APP_TOKEN') + app_token: Optional[str] = Field(None, validation_alias='SLACK_APP_TOKEN') # https://api.slack.com/apps - bot_token: str = Field(None, env='SLACK_BOT_TOKEN') + bot_token: Optional[str] = Field(None, validation_alias='SLACK_BOT_TOKEN') # https://api.slack.com/apps/XXXX/oauth? - secret: str = Field(None, env='SLACK_SIGNING_SECRET') + secret: Optional[str] = Field(None, validation_alias='SLACK_SIGNING_SECRET') # https://api.slack.com/authentication/verifying-requests-from-slack#signing_secrets_admin_page - proxy_address: str = Field(None, env="SLACK_BOT_PROXY_ADDRESS") # "all://127.0.0.1:7890" - bot_id: str = Field(None) - bot_username: str = Field(None) - - class Config: - env_file = '.env' - env_file_encoding = 'utf-8' + proxy_address: Optional[str] = Field(None, validation_alias="SLACK_BOT_PROXY_ADDRESS") # "all://127.0.0.1:7890" + bot_id: Optional[str] = Field(None) + bot_username: Optional[str] = Field(None) + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra="ignore") - @root_validator - def bot_setting_validator(cls, values): + @model_validator(mode='after') + def bot_setting_validator(self): try: - if values['app_token'] is None: + if self.app_token is None: raise ValueError("\n🍀Check:SlackBot app_token is empty") - if values['bot_token'] is None: + if self.bot_token is None: raise ValueError("\n🍀Check:SlackBot bot_token is empty") - if values['secret'] is None: + if self.secret is None: raise ValueError("\n🍀Check:SlackBot secret is empty") except Exception as e: logger.warning(e) else: logger.success(f"🍀Check:SlackBot token ready") - return values + return self @property def available(self): diff --git a/llmkira/setting/telegram.py b/llmkira/setting/telegram.py index ec07ff9ff..c19e1c7fa 100644 --- a/llmkira/setting/telegram.py +++ b/llmkira/setting/telegram.py @@ -3,49 +3,49 @@ # @Author : sudoskys # @File : telegram.py # @Software: PyCharm +from typing import Optional + from dotenv import load_dotenv from loguru import logger -from pydantic import BaseSettings, Field, root_validator +from pydantic import Field, model_validator +from pydantic_settings import BaseSettings, SettingsConfigDict class TelegramBot(BaseSettings): """ 代理设置 """ - token: str = Field(None, env='TELEGRAM_BOT_TOKEN') - proxy_address: str = Field(None, env="TELEGRAM_BOT_PROXY_ADDRESS") # "all://127.0.0.1:7890" - bot_link: str = Field(None, env='TELEGRAM_BOT_LINK') - bot_id: str = Field(None, env="TELEGRAM_BOT_ID") - bot_username: str = Field(None, env="TELEGRAM_BOT_USERNAME") - - class Config: - env_file = '.env' - env_file_encoding = 'utf-8' + token: Optional[str] = Field(None, validation_alias='TELEGRAM_BOT_TOKEN') + proxy_address: Optional[str] = Field(None, validation_alias="TELEGRAM_BOT_PROXY_ADDRESS") # "all://127.0.0.1:7890" + bot_link: Optional[str] = Field(None, validation_alias='TELEGRAM_BOT_LINK') + bot_id: Optional[str] = Field(None, validation_alias="TELEGRAM_BOT_ID") + bot_username: Optional[str] = Field(None, validation_alias="TELEGRAM_BOT_USERNAME") + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra="ignore") - @root_validator - def bot_validator(cls, values): - if values['proxy_address']: - logger.success(f"TelegramBot proxy was set to {values['proxy_address']}") - if values.get('token') is None: + @model_validator(mode='after') + def bot_validator(self): + if self.proxy_address: + logger.success(f"TelegramBot proxy was set to {self.proxy_address}") + if self.token is None: logger.warning(f"\n🍀Check:Telegrambot token is empty") - if values.get('bot_id') is None and values.get('token', None): + if self.bot_id is None and self.token: try: from telebot import TeleBot # 创建 Bot - if values['proxy_address'] is not None: + if self.proxy_address is not None: from telebot import apihelper - if "socks5://" in values['proxy_address']: - values['proxy_address'] = values['proxy_address'].replace("socks5://", "socks5h://") - apihelper.proxy = {'https': values['proxy_address']} - _bot = TeleBot(token=values.get('token')).get_me() - values['bot_id'] = _bot.id - values['bot_username'] = _bot.username - values['bot_link'] = f"https://t.me/{values['bot_username']}" + if "socks5://" in self.proxy_address: + self.proxy_address = self.proxy_address.replace("socks5://", "socks5h://") + apihelper.proxy = {'https': self.proxy_address} + _bot = TeleBot(token=self.token).get_me() + self.bot_id = str(_bot.id) + self.bot_username = _bot.username + self.bot_link = f"https://t.me/{self.bot_username}" except Exception as e: logger.error(f"\n🍀Check:Telegrambot token is empty:{e}") else: - logger.success(f"🍀Check:TelegramBot connect success: {values.get('bot_username')}") - return values + logger.success(f"🍀Check:TelegramBot connect success: {self.bot_username}") + return self @property def available(self): diff --git a/llmkira/task/__init__.py b/llmkira/task/__init__.py index f3d10d92c..3a59db16e 100644 --- a/llmkira/task/__init__.py +++ b/llmkira/task/__init__.py @@ -42,7 +42,7 @@ async def send_task(self, task: TaskHeader): # await channel.initialize(timeout=2000) # Creating a message message = Message( - body=task.json().encode("utf-8"), + body=task.model_dump_json().encode("utf-8"), delivery_mode=DeliveryMode.PERSISTENT, expiration=EXPIRATION_SECOND ) diff --git a/llmkira/task/schema.py b/llmkira/task/schema.py index 7b252bcaf..414fdff30 100644 --- a/llmkira/task/schema.py +++ b/llmkira/task/schema.py @@ -12,7 +12,8 @@ import orjson from dotenv import load_dotenv from loguru import logger -from pydantic import BaseSettings, Field, BaseModel, root_validator, validator +from pydantic import model_validator, ConfigDict, Field, BaseModel, PrivateAttr +from pydantic_settings import BaseSettings, SettingsConfigDict from telebot import types from llmkira.schema import RawMessage @@ -33,46 +34,33 @@ class RabbitMQ(BaseSettings): """ 代理设置 """ - amqp_dsn: str = Field("amqp://admin:8a8a8a@localhost:5672", env='AMQP_DSN') - _verify_status: bool = Field(False, env='VERIFY_STATUS') + amqp_dsn: str = Field("amqp://admin:8a8a8a@localhost:5672", validation_alias='AMQP_DSN') + _verify_status: bool = PrivateAttr(default=False) + model_config = SettingsConfigDict(env_file='.env', env_file_encoding='utf-8', extra="ignore") - class Config: - env_file = '.env' - env_file_encoding = 'utf-8' - - @root_validator() - def is_connect(cls, values): + @model_validator(mode='after') + def is_connect(self): import aio_pika try: sync(aio_pika.connect_robust( - values['amqp_dsn'] + self.amqp_dsn )) except Exception as e: - values['_verify_status'] = False - logger.error(f'\n⚠️ RabbitMQ DISCONNECT, pls set AMQP_DSN in .env\n--error {e} --dsn {values["amqp_dsn"]}') + self._verify_status = False + logger.error(f'\n⚠️ RabbitMQ DISCONNECT, pls set AMQP_DSN in .env\n--error {e} --dsn {self.amqp_dsn}') else: - values['_verify_status'] = True + self._verify_status = True logger.success(f"RabbitMQ connect success") - if values['amqp_dsn'] == "amqp://admin:8a8a8a@localhost:5672": + if self.amqp_dsn == "amqp://admin:8a8a8a@localhost:5672": logger.warning(f"\n⚠️ You are using the default RabbitMQ password") - return values + return self - def check_connection(self, values): - import aio_pika - try: - sync(aio_pika.connect_robust( - self.amqp_dsn - )) - except Exception as e: - logger.warning('RabbitMQ DISCONNECT, pls set AMQP_DSN in ENV') - raise ValueError(f'RabbitMQ connect failed {e}') - else: - logger.success(f"RabbitMQ connect success") - return values + @property + def available(self): + return self._verify_status @property def task_server(self): - return self.amqp_dsn @@ -84,25 +72,22 @@ class TaskHeader(BaseModel): """ 任务链节点 """ - - class Config: - json_loads = orjson.loads - json_dumps = orjson_dumps + model_config = ConfigDict() class Meta(BaseModel): class Callback(BaseModel): function_response: str = Field("empty response", description="工具响应内容") - name: str = Field(None, description="功能名称", regex=r"^[a-zA-Z0-9_]+$") + name: str = Field(None, description="功能名称", pattern=r"^[a-zA-Z0-9_]+$") tool_call_id: Optional[str] = Field(None, description="工具调用ID") - @root_validator() - def check(cls, values): + @model_validator(mode="after") + def check(self): """ 检查回写消息 """ - if not values.get("tool_call_id") and not values.get("name"): + if not self.tool_call_id and not self.name: raise ValueError("tool_call_id or name must be set") - return values + return self @classmethod def create(cls, @@ -143,7 +128,7 @@ def get_function_message(self) -> Union[FunctionMessage]: ) plan_chain_pending: List[TaskBatch] = Field(default=[], description="待完成的节点") plan_chain_length: int = Field(default=0, description="节点长度") - plan_chain_complete: bool = Field(False, description="是否完成此集群") + plan_chain_complete: Optional[bool] = Field(False, description="是否完成此集群") """功能状态与缓存""" function_enable: bool = Field(False, description="功能开关") @@ -153,7 +138,7 @@ def get_function_message(self) -> Union[FunctionMessage]: """携带插件的写回结果""" write_back: bool = Field(False, description="写回消息") callback: List[Callback] = Field( - default=None, + default=[], description="用于回写,插件返回的消息头,标识 function 的名字" ) @@ -169,34 +154,31 @@ def get_function_message(self) -> Union[FunctionMessage]: run_step_limit: int = Field(4, description="函数集群计数器上限") """函数中枢的依赖变量""" - verify_uuid: str = Field(None, description="认证链的UUID,根据此UUID和 Map 可以确定哪个需要执行") + verify_uuid: Optional[str] = Field(None, description="认证链的UUID,根据此UUID和 Map 可以确定哪个需要执行") verify_map: Dict[str, TaskBatch] = Field({}, description="函数节点的认证信息,经携带认证重发后可通过") llm_result: Any = Field(None, description="存储任务的衍生信息源") - llm_type: str = Field(None, description="存储任务的衍生信息源类型") + llm_type: Optional[str] = Field(None, description="存储任务的衍生信息源类型") extra_args: dict = Field({}, description="提供额外参数") - @validator("llm_result") - def validate_llm_result(cls, val): - if isinstance(val, dict): - if not val.get("model"): + @model_validator(mode="after") + def check(self): + if isinstance(self.llm_result, dict): + if not self.llm_result.get("model"): raise TypeError("Invalid llm_result") - return val - @root_validator() - def check(cls, values): - if not any([values["callback_forward"], values["callback_forward_reprocess"], values["direct_reply"]]): - if values["write_back"]: + if not any([self.callback_forward, self.callback_forward_reprocess, self.direct_reply]): + if self.write_back: logger.warning("you shouldn*t write back without callback_forward or direct_reply") - values["write_back"] = False + self.write_back = False # If it is the root node, it cannot be written back. # Because the message posted by the user is always the root node. # Writing back will cause duplicate messages. # Because the middleware will write the message back - if values["sign_as"][0] == 0 and values["write_back"]: + if self.sign_as[0] == 0 and self.write_back: logger.warning("root node shouldn*t write back") - values["write_back"] = False - return values + self.write_back = False + return self @classmethod def from_root(cls, release_chain, function_enable, platform: str = "default", **kwargs): @@ -221,7 +203,7 @@ def pack_loop( verify_map = self.verify_map if not plan_chain_pending: raise ValueError("plan_chain_pending can't be empty") - _new = self.copy(deep=True) + _new = self.model_copy(deep=True) _new.plan_chain_pending = plan_chain_pending _new.plan_chain_length = len(plan_chain_pending) _new.verify_map = verify_map @@ -275,9 +257,9 @@ def is_complete(self) -> bool: """ return self.plan_chain_complete - class Config: - extra = "ignore" - arbitrary_types_allowed = True + model_config = ConfigDict(extra="ignore", + arbitrary_types_allowed=True + ) def child(self, name) -> "TaskHeader.Meta": """ @@ -285,7 +267,7 @@ def child(self, name) -> "TaskHeader.Meta": """ self.sign_as = (self.sign_as[0] + 1, "child", name) self.run_step_already += 1 - return self.copy(deep=True) + return self.model_copy(deep=True) def chain(self, name, @@ -302,7 +284,7 @@ def chain(self, self.direct_reply = False self.write_back = write_back self.release_chain = release_chain - return self.copy(deep=True) + return self.model_copy(deep=True) def reply_direct(self, *, @@ -410,12 +392,12 @@ class Location(BaseModel): thread_id: Optional[Union[str, int]] = Field(None, description="channel id/Telegram thread") message_id: Optional[Union[str, int]] = Field(None, description="message id") - @root_validator() - def to_string(cls, values): - for key in values: - if isinstance(values[key], int): - values[key] = str(values[key]) - return values + @model_validator(mode="after") + def to_string(self): + for key in ["user_id", "chat_id", "thread_id", "message_id"]: + if isinstance(getattr(self, key), int): + setattr(self, key, str(getattr(self, key))) + return self @property def uid(self): @@ -772,7 +754,7 @@ def _convert(_message: "SlackMessageEvent") -> Optional[RawMessage]: chat_id=chat_id, thread_id=thread_id, text=text if text else f"(empty message)", - created_at=created_at + created_at=str(created_at) ) deliver_message_list: List[RawMessage] = [_convert(msg) for msg in deliver_back_message]