Skip to content

Commit

Permalink
✨ chore(poetry.lock): update pydantic version ^2.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
sudoskys committed Nov 12, 2023
1 parent 7e18b30 commit 27a1d14
Show file tree
Hide file tree
Showing 42 changed files with 513 additions and 517 deletions.
6 changes: 5 additions & 1 deletion llmkira/error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,15 @@
# @File : error.py
# @Software: PyCharm


class ReplyNeededError(Exception):
"""
Raised a error that need reply
"""
pass

def __init__(self, message: str = None, *args):
# 拦截 url 信息
super().__init__(message, *args)


# 更安全的 format
Expand Down
8 changes: 4 additions & 4 deletions llmkira/extra/plugins/_finish.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
from pydantic import ConfigDict

__package__name__ = "llmkira.extra.plugins.finish"
__plugin_name__ = "finish_conversation"
__openapi_version__ = "20231111"
Expand Down Expand Up @@ -31,9 +33,7 @@

class Finish(BaseModel):
comment: str = Field(default=":)", description="end with a question or a comment.(__language: $context)")

class Config:
extra = "allow"
model_config = ConfigDict(extra="allow")


class FinishTool(BaseTool):
Expand Down Expand Up @@ -99,7 +99,7 @@ async def run(self,
"""
处理message,返回message
"""
_set = Finish.parse_obj(arg)
_set = Finish.model_validate(arg)
# META
_meta = task.task_meta.reply_message(
plugin_name=__plugin_name__,
Expand Down
8 changes: 4 additions & 4 deletions llmkira/extra/plugins/_translate_doc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
from pydantic import ConfigDict

__package__name__ = "llmkira.extra.plugins.translate_file"
__plugin_name__ = "translate_file"
__openapi_version__ = "20231027"
Expand Down Expand Up @@ -44,9 +46,7 @@
class Translate(BaseModel):
language: str
file_id: str

class Config:
extra = "allow"
model_config = ConfigDict(extra="allow")


class TranslateTool(BaseTool):
Expand Down Expand Up @@ -179,7 +179,7 @@ async def run(self,
for i in item.file:
_translate_file.append(i)
try:
translate_arg = Translate.parse_obj(arg)
translate_arg = Translate.model_validate(arg)
except Exception:
raise ValueError("Please specify the following parameters clearly\n file_id=xxx,language=xxx")
_file_obj = [await i.raw_file()
Expand Down
34 changes: 15 additions & 19 deletions llmkira/extra/plugins/alarm/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
from pydantic import field_validator, ConfigDict

__package__name__ = "llmkira.extra.plugins.alarm"
__plugin_name__ = "set_alarm_reminder"
__openapi_version__ = "20231111"
Expand All @@ -10,18 +12,17 @@

import datetime
import re
import time

from loguru import logger
from pydantic import validator, BaseModel
from pydantic import BaseModel

from llmkira.receiver.aps import SCHEDULER
from llmkira.schema import RawMessage
from llmkira.sdk.func_calling import PluginMetadata, BaseTool
from llmkira.sdk.func_calling.schema import FuncPair
from llmkira.task import Task, TaskHeader

from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from llmkira.sdk.schema import TaskBatch
Expand All @@ -44,11 +45,9 @@
class Alarm(BaseModel):
delay: int
content: str
model_config = ConfigDict(extra="allow")

class Config:
extra = "allow"

@validator("delay")
@field_validator("delay")
def delay_validator(cls, v):
if v < 0:
raise ValueError("delay must be greater than 0")
Expand All @@ -58,9 +57,9 @@ def delay_validator(cls, v):
async def send_notify(_platform, _meta, _sender: dict, _receiver: dict, _user, _chat, _content: str):
await Task(queue=_platform).send_task(
task=TaskHeader(
sender=TaskHeader.Location.parse_obj(_sender), # 继承发送者
receiver=TaskHeader.Location.parse_obj(_receiver), # 因为可能有转发,所以可以单配
task_meta=TaskHeader.Meta.parse_obj(_meta),
sender=TaskHeader.Location.model_validate(_sender), # 继承发送者
receiver=TaskHeader.Location.model_validate(_receiver), # 因为可能有转发,所以可以单配
task_meta=TaskHeader.Meta.model_validate(_meta),
message=[
RawMessage(
user_id=_user,
Expand All @@ -79,7 +78,7 @@ class AlarmTool(BaseTool):
silent: bool = False
function: Function = alarm
keywords: list = ["闹钟", "提醒", "定时", "到点", '分钟']
pattern = re.compile(r"(\d+)(分钟|小时|天|周|月|年)后提醒我(.*)")
pattern: Optional[re.Pattern] = re.compile(r"(\d+)(分钟|小时|天|周|月|年)后提醒我(.*)")
require_auth: bool = True

# env_required: list = ["SCHEDULER", "TIMEZONE"]
Expand Down Expand Up @@ -148,7 +147,7 @@ async def run(self,
"""
处理message,返回message
"""
_set = Alarm.parse_obj(arg)
_set = Alarm.model_validate(arg)
_meta = task.task_meta.reply_message(
plugin_name=__plugin_name__,
callback=[
Expand All @@ -163,22 +162,19 @@ async def run(self,
logger.debug("Plugin:set alarm {} minutes later".format(_set.delay))
SCHEDULER.add_job(
func=send_notify,
id=str(time.time()),
id=str(receiver.user_id),
trigger="date",
replace_existing=True,
misfire_grace_time=1000,
run_date=datetime.datetime.now() + datetime.timedelta(minutes=_set.delay),
args=[
task.receiver.platform,
_meta.dict(),
task.sender.dict(), receiver.dict(),
_meta.model_dump(),
task.sender.model_dump(), receiver.model_dump(),
receiver.user_id, receiver.chat_id,
_set.content
]
)
try:
SCHEDULER.start()
except Exception as e:
print(f"[155035]{e}")
await Task(queue=receiver.platform).send_task(
task=TaskHeader(
sender=task.sender, # 继承发送者
Expand Down
8 changes: 4 additions & 4 deletions llmkira/extra/plugins/search.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
from pydantic import ConfigDict

__package__name__ = "llmkira.extra.plugins.search"
__plugin_name__ = "search_in_google"
__openapi_version__ = "20231111"
Expand Down Expand Up @@ -74,9 +76,7 @@ def search_on_duckduckgo(search_sentence: str, key_words: str = None):

class Search(BaseModel):
keywords: str

class Config:
extra = "allow"
model_config = ConfigDict(extra="allow")


class SearchTool(BaseTool):
Expand Down Expand Up @@ -200,7 +200,7 @@ async def run(self,
处理message,返回message
"""

_set = Search.parse_obj(arg)
_set = Search.model_validate(arg)
_search_result = search_on_duckduckgo(_set.keywords)
# META
_meta = task.task_meta.reply_raw(
Expand Down
16 changes: 8 additions & 8 deletions llmkira/extra/plugins/sticker.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# -*- coding: utf-8 -*-
from pydantic import field_validator, ConfigDict

__package__name__ = "llmkira.extra.plugins.sticker"
__plugin_name__ = "convert_to_sticker"
__openapi_version__ = "20231111"
Expand All @@ -15,13 +17,13 @@

from PIL import Image
from loguru import logger
from pydantic import validator, BaseModel
from pydantic import BaseModel

from llmkira.schema import RawMessage
from llmkira.sdk.func_calling import BaseTool
from llmkira.sdk.func_calling.schema import FuncPair, PluginMetadata
from llmkira.task import Task, TaskHeader
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
from llmkira.sdk.schema import TaskBatch
Expand All @@ -43,11 +45,9 @@
class Sticker(BaseModel):
yes_no: str
comment: str = "done"
model_config = ConfigDict(extra="allow")

class Config:
extra = "allow"

@validator("yes_no")
@field_validator("yes_no")
def delay_validator(cls, v):
if v != "yes":
v = "no"
Expand Down Expand Up @@ -86,7 +86,7 @@ class StickerTool(BaseTool):
"""
function: Function = sticker
keywords: list = ["转换", "贴纸", ".jpg", "图像", '图片']
file_match_required = re.compile(r"(.+).jpg|(.+).png|(.+).jpeg|(.+).gif|(.+).webp|(.+).svg")
file_match_required: Optional[re.Pattern] = re.compile(r"(.+).jpg|(.+).png|(.+).jpeg|(.+).gif|(.+).webp|(.+).svg")

def pre_check(self):
return True
Expand Down Expand Up @@ -158,7 +158,7 @@ async def run(self,
if item.file:
for i in item.file:
_file.append(i)
_set = Sticker.parse_obj(arg)
_set = Sticker.model_validate(arg)
_file_obj = [await i.raw_file() for i in sorted(set(_file), key=_file.index)]
# 去掉None
_file_obj = [item for item in _file_obj if item]
Expand Down
6 changes: 3 additions & 3 deletions llmkira/extra/user/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def __init__(self):
self.client = self.use_collection("user_cost")

async def insert(self, data: UserCost):
await self.client.insert_one(data.dict())
await self.client.insert_one(data.model_dump())
return data

async def read_by_uid(self, uid: str) -> List[UserCost]:
Expand All @@ -50,9 +50,9 @@ async def update(self, uid: str, data: UserConfig, validate: bool = True) -> Use
[("uid", 1)], unique=True
)
try:
await self.client.insert_one(data.dict())
await self.client.insert_one(data.model_dump(mode="json"))
except DuplicateKeyError:
await self.client.update_one({"uid": uid}, {"$set": data.dict()})
await self.client.update_one({"uid": uid}, {"$set": data.model_dump(mode="json")})
return data

async def read_by_uid(self, uid: str) -> Optional[UserConfig]:
Expand Down
46 changes: 22 additions & 24 deletions llmkira/extra/user/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@
from enum import Enum
from typing import List, Union, Optional

from pydantic import BaseModel, Field, BaseSettings, validator
from pydantic import field_validator, ConfigDict, BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict

from ...sdk.endpoint import Driver

Expand All @@ -35,17 +36,17 @@ class Cost(BaseModel):
"""
cost_by: str = Field("chat", description="环节")
token_usage: int = Field(0)
token_uuid: str = Field(None, description="Api Key 的 hash")
model_name: str = Field(None, description="Model Name")
token_uuid: Optional[str] = Field(None, description="Api Key 的 hash")
llm_model: Optional[str] = Field(None, description="Model Name")
provide_type: int = Field(None, description="认证模式")

@classmethod
def by_function(cls, function_name: str,
token_usage: int,
token_uuid: str,
model_name: str,
llm_model: str,
):
return cls(cost_by=function_name, token_usage=token_usage, token_uuid=token_uuid, model_name=model_name)
return cls(cost_by=function_name, token_usage=token_usage, token_uuid=token_uuid, model_name=llm_model)

request_id: str = Field(default=None, description="请求 UUID")
uid: str = Field(default=None, description="用户 UID ,注意是平台+用户")
Expand All @@ -70,7 +71,7 @@ def create_from_function(
function_name=cost_by,
token_usage=token_usage,
token_uuid=token_uuid,
model_name=model_name,
llm_model=model_name,
),
cost_time=int(time.time()),
)
Expand All @@ -89,13 +90,11 @@ def create_from_task(
cost_time=int(time.time()),
)

class Config:
extra = "ignore"
allow_mutation = True
arbitrary_types_allowed = True
validate_assignment = True
validate_all = True
validate_on_assignment = True
model_config = ConfigDict(extra="ignore",
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True
)


class UserConfig(BaseSettings):
Expand Down Expand Up @@ -133,7 +132,7 @@ def set_proxy_public(self, token: str, provider: str):
self.token = token
return self

@validator("provider")
@field_validator("provider")
def upper_provider(cls, v):
if v:
return v.upper()
Expand All @@ -146,12 +145,12 @@ class PluginConfig(BaseModel):
def default(cls):
return cls()

def block(self, plugin_name: str) -> "PluginConfig":
def block(self, plugin_name: str) -> "UserConfig.PluginConfig":
if plugin_name not in self.block_list:
self.block_list.append(plugin_name)
return self

def unblock(self, plugin_name: str) -> "PluginConfig":
def unblock(self, plugin_name: str) -> "UserConfig.PluginConfig":
if plugin_name in self.block_list:
self.block_list.remove(plugin_name)
return self
Expand All @@ -162,16 +161,15 @@ def unblock(self, plugin_name: str) -> "PluginConfig":
plugin_subs: PluginConfig = Field(default_factory=PluginConfig.default, description="插件订阅")
llm_driver: LlmConfig = Field(default_factory=LlmConfig.default, description="驱动")

@validator("uid")
@field_validator("uid")
def check_user_id(cls, v):
if v:
return str(v)
return v

class Config:
extra = "ignore"
allow_mutation = True
arbitrary_types_allowed = True
validate_assignment = True
validate_all = True
validate_on_assignment = True
model_config = SettingsConfigDict(extra="ignore",
frozen=True,
arbitrary_types_allowed=True,
validate_assignment=True,
validate_default=True
)
Loading

0 comments on commit 27a1d14

Please sign in to comment.