From b8ea363a898d8ee5c5d9647aaa37202b69f83d87 Mon Sep 17 00:00:00 2001 From: Alex-TG001 <60740185+Alex-TG001@users.noreply.github.com> Date: Wed, 14 Aug 2024 17:08:20 +0800 Subject: [PATCH 1/2] Create main.ipynb --- .../dialogue multi-tag generation/main.ipynb | 678 ++++++++++++++++++ 1 file changed, 678 insertions(+) create mode 100644 cookbook/awesome_demo/dialogue multi-tag generation/main.ipynb diff --git a/cookbook/awesome_demo/dialogue multi-tag generation/main.ipynb b/cookbook/awesome_demo/dialogue multi-tag generation/main.ipynb new file mode 100644 index 00000000..34551acf --- /dev/null +++ b/cookbook/awesome_demo/dialogue multi-tag generation/main.ipynb @@ -0,0 +1,678 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 客服对话多标签生成\n", + "\n", + "在客服对话场景中,可以通过大模型分析用户与客服之间的对话信息,准确识别用户的意图和对应原因,生成对应标签为后续回复和营销策略服务。比如,用户因地址填写错误与客服沟通,说希望能够取消订单,则取消订单是意图,地址填写错误是原因。最初我们选择小模型进行多标签生成,在初期使用中展现出一定的效果,能够在较短时间内进行部署并提供基础的标签生成功能。然而,随着业务需求的日益复杂,现有的小模型在多标签生成上存在一些明显的问题和挑战:\n", + "\n", + "\n", + "* 标签准确率不高:小模型的标签准确率通常保持在接近80%左右,无法满足业务进一步提高准确率的期望。业务需求日益复杂,用户提出的问题多样化,小模型的识别能力有限,导致部分标签无法准确标识。\n", + "* 标注数据需求量大:训练一个有效的小模型,每个标签至少需要300个标注数据,人工成本高昂。尤其是在业务需求不断变化的情况下,标注工作量进一步增加。\n", + "* 对标签体系的依赖性强:小模型对标签体系还会有较强的依赖,一旦业务标签体系发生较大变化,例如三层标签扩展到四层标签或大规模调整标签结构,标注和训练工作需要大规模重复建设。\n", + "\n", + "\n", + "针对上述问题,我们提出了使用ERNIE Tiny大模型进行微调的解决方案。ERNIE Tiny大模型在语言理解能力上更强,能够在较少标注数据的情况下,达到或超过现有小模型的准确率要求。通过微调训练,ERNIE Tiny大模型能够更好地适应业务需求的变化,提高多标签生成的准确性和效率,从而更好地支持客服对话场景的应用。" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "from qianfan import ChatCompletion\n", + "from qianfan.dataset import Dataset\n", + "from qianfan.common import Prompt\n", + "from qianfan.trainer import LLMFinetune\n", + "from qianfan.trainer.consts import PeftType\n", + "from qianfan.trainer.configs import TrainConfig\n", + "import os\n", + "from qianfan.dataset import Dataset\n", + "from qianfan.dataset.data_source import BosDataSource\n", + 
"from qianfan.dataset.data_source.base import FormatType" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "os.environ[\"QIANFAN_ACCESS_KEY\"] = \"your_access_key\"\n", + "os.environ[\"QIANFAN_SECRET_KEY\"] = \"your_secret_key\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 1. 基座模型效果示例" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "首先,我们选择了ERNIE-Tiny-8K模型作为本次实验的基座模型。" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "chat = ChatCompletion(model=\"ERNIE-Tiny-8K\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 例一" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "根据对话内容,最有可能的客户意图和对应的原因标签是:\n", + "“客户意图”:订单取消\n", + "“原因标签”:解决问题\n", + "\n", + "解释:客户李先生表示地址写错了需要取消订单,王琳K表示可以帮他取消订单并提供了相应的帮助。因此,可以判断客户的意图是订单取消。\n" + ] + } + ], + "source": [ + "target ={ \n", + " \"conversation\": (\n", + " \"王琳K:欢迎光临DianCan披萨,为了给您提供更加优质的服务,请问您有什么具体的问题或需要帮助吗?\"\n", + " \"客户李先生:我刚刚下了一个订单,但是地址写错了,能帮我取消吗?\"\n", + " \"王琳K:非常抱歉给您带来困扰,我可以帮您取消订单。为了确认您的身份,需要您提供订单号或者下单时使用的电话号码,可以吗?\"\n", + " \"客户李先生:我的订单号是DC123456789,电话号码是138****1234。\"\n", + " \"王琳K:非常感谢您提供的信息,我已经找到了您的订单。现在我将为您取消该订单,请稍等片刻。\"\n", + " \"客户李先生:好的,谢谢。\"\n", + " \"王琳K:您的订单已经成功取消。如有其他问题,请随时联系我们。感谢您的理解和支持。\"\n", + " \"客户李先生:非常感谢你们的帮助,我会重新下单的。\"\n", + " \"王琳K:非常高兴能够帮助您解决问题。祝您用餐愉快!如有其他问题,请随时联系我们。\"\n", + ")\n", + "}\n", + "prompt = Prompt(\"\"\"你是一个对话意图识别并打标签的机器人,根据下面的已知信息,打上标签。\n", + " 请使用以下格式输出:{\"意图\": xxx\n", + " \"原因\": xxx}\n", + "请根据以下会话的内容,精准判断最有可能的客户意图以及对应的原因标签,意图和原因标签必须严格控制在给定的范围之内。:\n", + "\n", + "{conversation}\n", + "\"\"\")\n", + "\n", + "resp = chat.do(messages=[{\"role\": \"user\", \"content\": prompt.render(**target)[0]}])\n", + "\n", + "print(resp[\"result\"])" + ] + }, + { + "cell_type": "markdown", 
+ "metadata": {}, + "source": [ + "## 例二" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "```json\n", + "{\n", + "\"意图\": \"客户服务\",\n", + "\"原因\": \"顾客表达想要取消订单,处理地址错误的情况,机器人回应询问订单号和手机号码后四位以确认身份,并提供了处理订单取消及退款的具体措施\"\n", + "}\n", + "```\n" + ] + } + ], + "source": [ + "target ={ \n", + " \"conversation\": (\n", + " \"DianCan 自助点餐机器人:您好,欢迎光临DianCan披萨,有什么可以为您服务的吗?\"\n", + " \"顾客:你好,我刚刚下了一个订单,但是我发现我填写的地址是错的。\"\n", + " \"DianCan 自助点餐机器人:非常抱歉给您带来困扰。请问您是否希望取消订单并重新下单呢?\"\n", + " \"顾客:是的,我想取消订单。\"\n", + " \"DianCan 自助点餐机器人:好的,请告诉我您的订单号,我会尽快帮您处理。\"\n", + " \"顾客:我的订单号是XXXX。\"\n", + " \"DianCan 自助点餐机器人:好的,已经为您查询到了订单。为了确认您的身份,请问您能提供下单时使用的手机号码后四位吗?\"\n", + " \"顾客:手机号码后四位是XXXX。\"\n", + " \"DianCan 自助点餐机器人:非常感谢,已经确认您的身份。我们现在就为您取消订单,并会尽快处理退款。退款将在3-7个工作日内原路返回至您的支付账户。请您注意查收。\"\n", + " \"顾客:好的,非常感谢你们的帮助。\"\n", + " \"DianCan 自助点餐机器人:不客气,如果您还有其他问题或需要进一步的帮助,请随时与我们联系。祝您用餐愉快!\"\n", + ")\n", + "}\n", + "\n", + "resp = chat.do(messages=[{\"role\": \"user\", \"content\": prompt.render(**target)[0]}])\n", + "\n", + "print(resp[\"result\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "从上述两组例中,我们可以总结出以下问题:\n", + "\n", + "\n", + "* 问题一:微调前模型的输出可能并不能完全遵循指定格式进行输出。\n", + "* 问题二:基座模型的输出无法精准识别客户意图和原因。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 2. 
模型精调数据准备" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.1 构造“意图-原因”标签集" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "| 意图 | 原因 | 数据编号 |\n", + "|----------------------------|-------------------------------|----------|\n", + "| 如何提交评价 | 未收到评价邀请 | 1 |\n", + "| 你们有某种餐品吗 | 餐品缺货_短期或长期 | 1 |\n", + "| 某餐品需要做成不辣 | 顾客特需服务 | 1 |\n", + "| 在餐厅丢失了物品怎么寻回 | 找回遗失物品 | 1 |\n", + "| 订单什么时候能做好 | 餐品制作时间 | 1 |\n", + "| 如何访问我的订单历史记录 | 历史订单查询 | 1 |\n", + "| 取消订单 | 取消订单_无具体理由 | 1 |\n", + "| 为什么我的优惠券不见了 | 优惠券未到账 | 1 |\n", + "| 取消订单 | 地址填写错误 | 1 |\n", + "| 我想要加番茄酱或者不加番茄酱 | 加or不加番茄酱_顾客定制 | 2 |\n", + "| 餐品不对 | 漏餐错餐 | 1 |\n", + "| 餐品配送太少 | 食物变质 | 1 |\n", + "| 我在哪里可以参加活动 | 活动地点咨询 | 1 |\n", + "| 有没有推荐的产品 | 需要推荐餐品 | 1 |\n", + "| 客户要开发票 | 开发票 | 2 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 2.2 生成对话数据" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "从“意图-原因”生成客服对话的Prompt:" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "conversation_prompt =\"\"\"你是一个利用【意图-原因】生成对话的机器人,请你仔细观察下面的输入输出,发挥你的想象,根据【输入】提供的已知信息和要求,生成客服对话。具体要求如下:\n", + "1.【输入】和【输出】的格式与示例相同。\n", + "2.【输出】为生成的对话。\n", + "3.你生成的是DianCan披萨公司的客服对话,请你刻意回避百胜公司、肯德基餐饮相关的产品和名词。\n", + "遵循以上准则,请你根据【输入】创造一个新的客服对话,示例如下:\n", + "\n", + "【输入】\n", + "###意图-原因\n", + "\"{\"意图\": \"某产品你们有吗\"\n", + "\"原因\": \"产品是否有_餐厅断货_临时or永久\"}\"\n", + "【输出】\n", + "### 对话内容\n", + "\"客服:正在为您转接人工服务中,目前人工繁忙,如需继续等待请输入:继续\n", + "客服:欢迎进入人工客服通道,56959很高兴为您服务对话过程中以及完成后,您会收到评价提醒,希望您能对我个人本次服务做个评价,您的反馈和建议也是我努力的方向哦,感谢~\n", + "顾客:继续\n", + "顾客:你好,鸡腿饭现在是下架了嘛\n", + "客服:DianCan客服中心,很高兴为您服务,我先查看一下您反馈的问题哦~\n", + "顾客:附近每家店子都没有[嚎哭]\n", + "客服:没有看到说明售罄,暂时断货,没有这个餐点,建议客官过段时间在购买查看的,不好意思。\n", + "客服:以您在我们官网看到的为准,有就是有,没有就是没有的呢。\n", + "顾客:那可以查一下附近哪家店有吗\n", + "顾客:武汉洪山区哪家店有\n", + "顾客:我看了好多家都没有[嚎哭]\n", + "客服:小二这边是客服中心的,不是某家餐厅,不是很清楚每个门店的具体情况,非常抱歉。\n", + "客服:或者您可以通过DianCan微信公众号-自助服务-点击入群,可加入附近DianCan餐厅的社群哦~了解具体信息。\n", +
"客服:亲亲,您还在线吗?我还在快马加鞭处理中,如果您还有问题,也及时回应哦~\n", + "客服:\"\n", + "---------------------------------\n", + "【输入】 \n", + "###意图-原因\n", + "\"{\"意图\": \"你们有某种餐品吗\"\n", + "\"原因\": \"餐品断货_短期或长期\"}\"\n", + "【输出】:\"\"\"\n" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "客服:您好,目前正在为您查询,关于您提到的餐品,我们暂时没有收到关于断货或下架的信息。建议您可以通过我们的官网或者微信公众号了解最新产品信息。如果您需要其他帮助或有其他问题,请随时告知,我们会尽快为您处理。\n", + "\n", + "如果您对餐品的供应情况有疑问或需要了解更多信息,建议您通过我们的官方渠道查询最新消息。如果您还有其他问题或需求,请随时告知,我们会竭诚为您服务。\n" + ] + } + ], + "source": [ + "resp = chat.do(messages=[{\"role\": \"user\", \"content\": conversation_prompt}])\n", + "print(resp[\"result\"])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# 3. SFT调优示例" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3.1 数据集导入" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "在完成上述的数据集准备工作后,我们可以开始进行模型微调训练。" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "首先从平台中获取微调用的训练集" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO][2024-08-12 16:47:47.213] dataset.py:430 [t:8570851136]: no data source was provided, construct\n", + "[INFO][2024-08-12 16:47:47.214] dataset.py:282 [t:8570851136]: construct a qianfan data source from existed id: ds-scm8g98a7pv3zzf3, with args: {'format': }\n", + "[INFO][2024-08-12 16:47:47.989] dataset_utils.py:317 [t:8570851136]: list qianfan dataset data by 0\n", + "[INFO][2024-08-12 16:47:48.404] dataset_utils.py:339 [t:8570851136]: received dataset list from qianfan dataset\n", + "[INFO][2024-08-12 16:47:48.405] dataset_utils.py:347 [t:8570851136]: retrieve single entity from 
https://easydata.bj.bcebos.com/_system_/dataset/ds-scm8g98a7pv3zzf3/texts/data/raw_aca7da8ef71c956315d9a7dc3874a4d5a65280bb382263c39aa16482de1b666e_91dc9885731b405bb32a5d5734c4dd5f?authorization=bce-auth-v1%2F50c8bb753dcb4e1d8646bb1ffefd3503%2F2024-08-12T08%3A47%3A48Z%2F7200%2Fhost%2F4042ce5b329d3026629877a2108adae8293bc615d5b4cc9c76948f9a2a917147 in try 0\n", + "[INFO][2024-08-12 16:47:48.631] dataset_utils.py:361 [t:8570851136]: retrieve single entity from https://easydata.bj.bcebos.com/_system_/dataset/ds-scm8g98a7pv3zzf3/texts/data/raw_aca7da8ef71c956315d9a7dc3874a4d5a65280bb382263c39aa16482de1b666e_91dc9885731b405bb32a5d5734c4dd5f?authorization=bce-auth-v1%2F50c8bb753dcb4e1d8646bb1ffefd3503%2F2024-08-12T08%3A47%3A48Z%2F7200%2Fhost%2F4042ce5b329d3026629877a2108adae8293bc615d5b4cc9c76948f9a2a917147 succeeded, with content: [{\"prompt\": \"假设你有一套客户意图分类以及该分类下属的原因标签。请根据给定的客服对话内容,判断最有可能的客户意图以及对应的原因标签,意图和原因标签需要严格控制给定的范围之内;一个意图可能对应多个原因,但一个原因只会对应一个意图;如果均不匹配则回答无明确客户意图;回答请使用json的格式,示例:'{\\\"意图\\\": \\\"xxx\\\",\\\"原因”: \\\"xxx\\\"}'\\n### 下面是客户意图的分类\\n1.客户意图:如何提交评价;原因标签:未收到评价邀请\\n2.客户意图:你们有某种餐品吗;原因标签:餐品缺货_短期或长期\\n3.客户意图:某餐品需要做成不辣;原因标签:顾客特需服务\\n4.客户意图:在餐厅丢失了物品怎么寻回;原因标签:找回遗失物品\\n5.客户意图:订单什么时候能做好;原因标签:餐品制作时间\\n6.客户意图:如何访问我的订单历史记录;原因标签:历史订单查询\\n7.客户意图:取消订单;原因标签:取消订单_无具体理由\\n8.客户意图:为什么我的优惠券不见了;原因标签:优惠券未到账\\n9.客户意图:取消订单;原因标签:地址填写错误\\n10.客户意图:我想要加番茄酱或者不加番茄酱;原因标签:加or不加番茄酱_顾客定制\\n11.客户意图:餐品不对;原因标签:漏餐错餐\\n12.客户意图:餐厅电话是多少;原因标签:食物变质\\n13.客户意图:我在哪里可以参加活动;原因标签:活动地点咨询\\n14.客户意图:有没有推荐的产品;原因标签:需要推荐餐品\\n15.客户意图:客户要开发票;原因标签:开发票\\n\\n###对话内容\\n李星辰DY 2023年07月19日 10:23:48\\n您好,欢迎光临DianCan披萨,有什么可以为您服务的吗?\\n用户673210 2023年07月19日 10:24:12\\n我想开一下发票\\n李星辰DY 2023年07月19日 10:24:35\\n当然可以,请您提供一下订单号和开票信息,我们会尽快为您处理。\\n用户673210 2023年07月19日 10:25:01\\n订单号是DC230715001,开票信息是公司名称:XX科技有限公司,税号:9132XXXXXXXXX\\n李星辰DY 2023年07月19日 10:25:38\\n好的,已经收到您的订单号和开票信息,我们会尽快为您开具发票并发送到您的邮箱。请问您的邮箱地址是什么?\\n用户673210 2023年07月19日 10:26:05\\n我的邮箱是[example@example.com](mailto:example@example.com)\\n李星辰DY 2023年07月19日 
10:26:32\\n非常感谢,我们已经记录下了您的邮箱地址。发票将在24小时内发送到您的邮箱,请注意查收。\\n用户673210 2023年07月19日 10:27:00\\n好的,谢谢!\\n李星辰DY 2023年07月19日 10:27:25\\n不客气,如果您有任何其他问题或需要进一步的帮助,请随时联系我们。祝您用餐愉快!\\n\\n###输出\", \"response\": [[\"{'意图': '客户要开发票', '原因': '开发票'}\"]]}]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[{'entity_id': 'aca7da8ef71c956315d9a7dc3874a4d5a65280bb382263c39aa16482de1b666e_91dc9885731b405bb32a5d5734c4dd5f', 'entity_content': '[{\"prompt\": \"假设你有一套客户意图分类以及该分类下属的原因标签。请根据给定的客服对话内容,判断最有可能的客户意图以及对应的原因标签,意图和原因标签需要严格控制给定的范围之内;一个意图可能对应多个原因,但一个原因只会对应一个意图;如果均不匹配则回答无明确客户意图;回答请使用json的格式,示例:\\'{\\\\\"意图\\\\\": \\\\\"xxx\\\\\",\\\\\"原因”: \\\\\"xxx\\\\\"}\\'\\\\n### 下面是客户意图的分类\\\\n1.客户意图:如何提交评价;原因标签:未收到评价邀请\\\\n2.客户意图:你们有某种餐品吗;原因标签:餐品缺货_短期或长期\\\\n3.客户意图:某餐品需要做成不辣;原因标签:顾客特需服务\\\\n4.客户意图:在餐厅丢失了物品怎么寻回;原因标签:找回遗失物品\\\\n5.客户意图:订单什么时候能做好;原因标签:餐品制作时间\\\\n6.客户意图:如何访问我的订单历史记录;原因标签:历史订单查询\\\\n7.客户意图:取消订单;原因标签:取消订单_无具体理由\\\\n8.客户意图:为什么我的优惠券不见了;原因标签:优惠券未到账\\\\n9.客户意图:取消订单;原因标签:地址填写错误\\\\n10.客户意图:我想要加番茄酱或者不加番茄酱;原因标签:加or不加番茄酱_顾客定制\\\\n11.客户意图:餐品不对;原因标签:漏餐错餐\\\\n12.客户意图:餐厅电话是多少;原因标签:食物变质\\\\n13.客户意图:我在哪里可以参加活动;原因标签:活动地点咨询\\\\n14.客户意图:有没有推荐的产品;原因标签:需要推荐餐品\\\\n15.客户意图:客户要开发票;原因标签:开发票\\\\n\\\\n###对话内容\\\\n李星辰DY 2023年07月19日 10:23:48\\\\n您好,欢迎光临DianCan披萨,有什么可以为您服务的吗?\\\\n用户673210 2023年07月19日 10:24:12\\\\n我想开一下发票\\\\n李星辰DY 2023年07月19日 10:24:35\\\\n当然可以,请您提供一下订单号和开票信息,我们会尽快为您处理。\\\\n用户673210 2023年07月19日 10:25:01\\\\n订单号是DC230715001,开票信息是公司名称:XX科技有限公司,税号:9132XXXXXXXXX\\\\n李星辰DY 2023年07月19日 10:25:38\\\\n好的,已经收到您的订单号和开票信息,我们会尽快为您开具发票并发送到您的邮箱。请问您的邮箱地址是什么?\\\\n用户673210 2023年07月19日 10:26:05\\\\n我的邮箱是[example@example.com](mailto:example@example.com)\\\\n李星辰DY 2023年07月19日 10:26:32\\\\n非常感谢,我们已经记录下了您的邮箱地址。发票将在24小时内发送到您的邮箱,请注意查收。\\\\n用户673210 2023年07月19日 10:27:00\\\\n好的,谢谢!\\\\n李星辰DY 2023年07月19日 10:27:25\\\\n不客气,如果您有任何其他问题或需要进一步的帮助,请随时联系我们。祝您用餐愉快!\\\\n\\\\n###输出\", \"response\": [[\"{\\'意图\\': \\'客户要开发票\\', \\'原因\\': \\'开发票\\'}\"]]}]'}]\n" + ] + } + ], + "source": [ + "ds = 
Dataset.load(qianfan_dataset_id = \"ds-scm8g98a7pv3zzf3\", format = FormatType.Jsonl)\n", + "print(ds[0])" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 3.2 微调训练与测试\n", + "\n", + "拿到一个训练场景或者任务后,往往比较难判断参数应该如何调整。一般使用默认的参数值进行训练即可,平台中的默认参数是多次实验的经验结晶。 接下来介绍参数配置中几个较为关键的参数:\n", + "\n", + "* 迭代轮次(Epoch): 控制训练过程中的迭代轮数。每增加一轮,代表使用训练集对模型多训练一次。\n", + "\n", + "* 学习率(Learning Rate): 是在梯度下降的过程中更新权重时的超参数,过高会导致模型难以收敛,过低则会导致模型收敛速度过慢,平台已给出默认推荐值,也可根据经验调整。\n", + "\n", + "* 序列长度:如果对话数据的长度较短,建议选择短的序列长度,可以提升训练的速度。\n", + "\n", + "本次也针对Epoch和Learning Rate进行简要的调参实验,详细实验结果可以看效果评估数据。\n", + "\n", + "如果您是模型训练的专家,千帆也提供了训练更多的高级参数供您选择。这里也建议您初期调参时步长可以设定稍大些,因为较小的超参变动对模型效果的影响小,会被随机波动掩盖。\n", + "\n", + "针对我们的任务,此处设计了六组sft实验,参数和训练方法配置如下。\n", + "\n", + "实验数据如下:\n", + "| | 实验1 | 实验2 | 实验3 | 实验4 | 实验5 | 实验6 |\n", + "|-|-|-|-|-|-|-|\n", + "| 精调方法 | LoRA | LoRA | LoRA | 全量 | 全量 | 全量 |\n", + "| Epoch | 6 | 3 | 3 | 3 | 6 | 3 |\n", + "| Learning Rate | 3e-4 | 3e-4 | 6e-4 | 3e-5 | 3e-5 | 4e-5 |" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "创建trainer任务" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [], + "source": [ + "trainer = LLMFinetune(\n", + " name = \"dialogue-multi-tag\",\n", + " train_type=\"ERNIE-Tiny-8K\",\n", + " train_config=TrainConfig(\n", + " epoch=1,\n", + " learning_rate=1e-5,\n", + " peft_type=PeftType.ALL,\n", + " ),\n", + " dataset=ds\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "启动训练任务" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO][2024-08-12 16:51:44.361] base.py:226 [t:8570851136]: trainer subprocess started, pid: 11416\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[None]\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO][2024-08-12 16:51:44.368] base.py:202 
[t:8570851136]: check running log in .qianfan_exec_cache/ZdTRr7iw/2024-08-12.log\n" + ] + } + ], + "source": [ + "trainer.start()\n", + "print(trainer.result)" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO][2024-08-12 17:23:18.156] dataset.py:430 [t:8570851136]: no data source was provided, construct\n", + "[INFO][2024-08-12 17:23:18.157] dataset.py:282 [t:8570851136]: construct a qianfan data source from existed id: ds-scm8g98a7pv3zzf3, with args: {}\n" + ] + }, + { + "data": { + "text/plain": [ + "{'datasets': {'versions': [{'versionId': 'ds-scm8g98a7pv3zzf3'}],\n", + " 'sourceType': 'Platform',\n", + " 'splitRatio': 20},\n", + " 'task_id': 'task-85tcmxg0try3',\n", + " 'job_id': 'job-xhk1gtuvvdbh',\n", + " 'metrics': {'BLEU-4': '99.58%',\n", + " 'ROUGE-1': '99.62%',\n", + " 'ROUGE-2': '99.60%',\n", + " 'ROUGE-L': '99.75%',\n", + " 'EDIT-DISTANCE': '0.11',\n", + " 'EMBEDDING-DISTANCE': '0.00'},\n", + " 'checkpoints': [],\n", + " 'model_set_id': 'am-mgf6icsebsa4',\n", + " 'model_id': 'amv-8t6qf24m4xcb'}" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "trainer.output" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [], + "source": [ + "from qianfan.model import Model\n", + "\n", + "# 从`version_id`构造模型:\n", + "m = Model(id='amv-8t6qf24m4xcb')\n", + "m.auto_complete_info()" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "[INFO][2024-08-13 14:37:24.325] dataset.py:430 [t:8570851136]: no data source was provided, construct\n", + "[INFO][2024-08-13 14:37:24.326] dataset.py:282 [t:8570851136]: construct a qianfan data source from existed id: ds-n1dg1czx3ciqrakr, with args: {}\n" + ] + } + ], + "source": [ + "eval_ds = 
Dataset.load(qianfan_dataset_id =\"ds-n1dg1czx3ciqrakr\")" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "from qianfan.evaluation.evaluator import QianfanRefereeEvaluator, QianfanRuleEvaluator\n", + "from qianfan.evaluation.consts import QianfanRefereeEvaluatorDefaultMetrics, QianfanRefereeEvaluatorDefaultSteps, QianfanRefereeEvaluatorDefaultMaxScore\n", + "\n", + "your_app_id = 105835560\n", + "\n", + "qianfan_evaluators = [\n", + " QianfanRefereeEvaluator(\n", + " app_id=your_app_id,\n", + " prompt_metrics=QianfanRefereeEvaluatorDefaultMetrics,\n", + " prompt_steps=\"\"\"你是一个好助手。请你为下面问题的回答打分\n", + " 问题如下: {src}\n", + " 标准答案如下:{tgt}\n", + " 回答如下:{prediction}\n", + " 请你遵照以下的评分步骤:{steps}\n", + " 根据答案的综合水平给出0到3的评分。如果答案存在明显的不合理之处,则应给出一个较低的评分。如果答案符合以上要求并且与参考答案含义相似,则应给出一个较高的评分。\n", + " 你的回答模版如下:\n", + " 评分: 此处只能回答整数评分\n", + " 原因: 此处只能回答评分原因\"\"\",\n", + " prompt_max_score=3,\n", + " ),\n", + " QianfanRuleEvaluator(using_accuracy=True, using_similarity=True),\n", + "]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from qianfan.evaluation.local_evaluator import LocalJudgeEvaluator\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "from typing import Any, Dict, List, Union, Optional\n", + "import qianfan\n", + "from qianfan import ChatCompletion\n", + "from qianfan.common import Prompt\n", + "from qianfan.utils.pydantic import Field\n", + "from qianfan.evaluation.evaluator import LocalEvaluator\n", + "from qianfan.evaluation.consts import (\n", + " QianfanRefereeEvaluatorPromptTemplate,\n", + " QianfanRefereeEvaluatorDefaultMaxScore,\n", + " QianfanRefereeEvaluatorDefaultMetrics,\n", + " QianfanRefereeEvaluatorDefaultSteps,\n", + ")\n", + "\n", + "\n", + "class LocalJudgeEvaluator(LocalEvaluator):\n", + "\n", + " model: Optional[ChatCompletion] = 
Field(default=None, description=\"model object\")\n", + " metric_name: str = Field(default=\"\", description=\"metric name for evaluation\")\n", + " evaluation_prompt: Prompt = Field(default=Prompt(QianfanRefereeEvaluatorPromptTemplate), description=\"concrete evaluation prompt string\")\n", + " prompt_metrics: str = Field(default=QianfanRefereeEvaluatorDefaultMetrics, description=\"evaluation metrics\")\n", + " prompt_steps: str = Field(default=QianfanRefereeEvaluatorDefaultSteps, description=\"evaluation steps\")\n", + " prompt_max_score: int = Field(default=QianfanRefereeEvaluatorDefaultMaxScore, description=\"max score for evaluation\")\n", + " \n", + " class Config:\n", + " arbitrary_types_allowed = True\n", + " def evaluate(\n", + " self, input: Union[str, List[Dict[str, Any]]], reference: str, output: str\n", + " ) -> Dict[str, Any]:\n", + " \"\"\"\n", + " 使用模型进行本地评估\n", + " :param input: 给定的prompt,\n", + " evaluateManager.eval()的is_chat参数为true时,\n", + " input为对话记录,否则为单字符串prompt\n", + " :param reference: 用户给定的标准答案\n", + " :param output: 大模型生成的结果,eval中由service生成,eval_only中由用户给定\n", + "\n", + " :return: 评估结果\n", + " \"\"\"\n", + " if isinstance(input, list):\n", + " if not isinstance(self.model, ChatCompletion):\n", + " raise ValueError(f\"model is not an instance of ChatCompletion\")\n", + " if len(input)!=1: # 只考虑ChatCompletion单文本输入的情况\n", + " raise ValueError(f\"chat history is not single text\")\n", + " input_content = input[0].get('content','')\n", + " \n", + " # 生成评价模板\n", + " prompt_text, _ = self.evaluation_prompt.render(\n", + " metric_name=self.metric_name,\n", + " criteria=self.prompt_metrics,\n", + " steps=self.prompt_steps,\n", + " max_score=self.prompt_max_score,\n", + " prompt=input_content,\n", + " response=reference,\n", + " )\n", + " \n", + " # 调用模型获得评分\n", + " msg = qianfan.Messages()\n", + " msg.append(prompt_text)\n", + " \n", + " resp = self.model.do(\n", + " messages=msg,\n", + " temperature=0.1,\n", + " top_p=1,\n", + " )\n", + " \n", + 
" # print(f'{self.metric_name}|{input[0]}|{output}|{resp[\"result\"].strip()}')\n", + " return {self.metric_name: resp[\"result\"].strip()}\n", + " else:\n", + " raise ValueError(f\"input in {type(input)} not supported\")" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "import qianfan\n", + "from qianfan.model import Service\n", + "from qianfan.evaluation import EvaluationManager\n", + "chat_comp = qianfan.ChatCompletion(model=\"ERNIE-4.0-8K\") # 实例化用于裁判的模型" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "em = EvaluationManager(local_evaluators=local_evaluators)\n", + "result = em.eval(\n", + " [m], ds,\n", + ")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.7" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} From edad1642d09761353b6ed5b7eae2fe7be7eeef5c Mon Sep 17 00:00:00 2001 From: NuODaniel Date: Sat, 17 Aug 2024 15:56:33 +0800 Subject: [PATCH 2/2] feat: support bearer token (#745) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: add support for qianfan console api calls (#732) * doc: 增加 go 自定义错误码文档 (#733) * support custom retry codes * add custom retry code example * release go v0.0.11 (#734) * fix: api v2 & bearer token * fix: api v2 * fix: lint and update version * Update evaluation_manager.py (#736) * Update evaluation_manager.py * Update consts.py * Afs datasource (#737) * 添加 AFSDatasource * Debug * feat: 添加数据集 V2 API (#739) * 添加数据集 V2 API * 追加注释 * fix text2img model list (#741) * Fix unclosed aiohttp session whlie exception raised (#743) * Fix unclosed 
aiohttp session whlie exception raised * format and lint * Update http_client.py --------- Co-authored-by: NuODaniel * fix: merge --------- Co-authored-by: Azure99 Co-authored-by: Liu Jun Co-authored-by: AlexT <60740185+Alex-TG001@users.noreply.github.com> Co-authored-by: Dobiichi-Origami <56953648+Dobiichi-Origami@users.noreply.github.com> Co-authored-by: Guocheng --- docs/inference.md | 34 +- go/README.md | 18 + go/qianfan/model_endpoint_retriever.go | 2 +- go/qianfan/text2img.go | 2 +- go/qianfan/version.go | 2 +- java/README.md | 6 +- java/example/pom.xml | 2 +- .../java/com/baidubce/ConsoleExample.java | 64 +++ .../com/baidubce/SystemMemoryExample.java | 113 ++++ java/pom.xml | 2 +- .../java/com/baidubce/qianfan/Qianfan.java | 17 + .../com/baidubce/qianfan/QianfanClient.java | 41 +- .../baidubce/qianfan/core/StreamIterator.java | 1 - .../qianfan/core/builder/ConsoleBuilder.java | 91 ++++ .../qianfan/model/console/ConsoleRequest.java | 62 +++ .../model/console/ConsoleResponse.java | 12 +- .../com/baidubce/qianfan/util/CollUtils.java | 63 +++ .../java/com/baidubce/qianfan/util/Json.java | 8 + .../qianfan/util/ParameterizedTypeImpl.java | 45 ++ python/pyproject.toml | 2 +- python/qianfan/__init__.py | 2 +- python/qianfan/config.py | 6 + python/qianfan/consts.py | 19 +- python/qianfan/dataset/__init__.py | 2 + python/qianfan/dataset/consts.py | 9 + .../qianfan/dataset/data_source/__init__.py | 2 + python/qianfan/dataset/data_source/afs.py | 272 ++++++++++ python/qianfan/dataset/data_source/bos.py | 47 +- python/qianfan/dataset/data_source/utils.py | 48 ++ python/qianfan/dataset/dataset.py | 27 + python/qianfan/errors.py | 6 + python/qianfan/evaluation/consts.py | 62 ++- .../qianfan/evaluation/evaluation_manager.py | 1 + python/qianfan/resources/__init__.py | 2 + python/qianfan/resources/auth/oauth.py | 298 ++++++++--- .../batch_inference/helper/helper.py | 13 + python/qianfan/resources/console/consts.py | 15 + python/qianfan/resources/console/data.py | 494 
++++++++++++++++++ python/qianfan/resources/console/iam.py | 63 +++ python/qianfan/resources/http_client.py | 4 +- python/qianfan/resources/llm/base.py | 49 +- .../qianfan/resources/llm/chat_completion.py | 25 +- python/qianfan/resources/requestor/base.py | 3 +- .../resources/requestor/console_requestor.py | 23 +- .../resources/requestor/openapi_requestor.py | 86 ++- .../qianfan/tests/chat_completion_v2_test.py | 48 ++ python/qianfan/tests/utils/mock_server.py | 57 +- python/qianfan/tests/utils/utils.py | 1 + python/qianfan/utils/fake_pyarrow/__init__.py | 6 + 49 files changed, 2085 insertions(+), 192 deletions(-) create mode 100644 java/example/src/main/java/com/baidubce/ConsoleExample.java create mode 100644 java/example/src/main/java/com/baidubce/SystemMemoryExample.java create mode 100644 java/src/main/java/com/baidubce/qianfan/core/builder/ConsoleBuilder.java create mode 100644 java/src/main/java/com/baidubce/qianfan/model/console/ConsoleRequest.java create mode 100644 java/src/main/java/com/baidubce/qianfan/util/CollUtils.java create mode 100644 java/src/main/java/com/baidubce/qianfan/util/ParameterizedTypeImpl.java create mode 100644 python/qianfan/dataset/data_source/afs.py create mode 100644 python/qianfan/resources/console/iam.py diff --git a/docs/inference.md b/docs/inference.md index 3224a301..66da0367 100644 --- a/docs/inference.md +++ b/docs/inference.md @@ -139,14 +139,42 @@ async for r in resp: - `json_body`:请求体 - `retry_config`:请求使用的重试信息 -##### V2 版本 +#### V2 版本 -千帆平台推出了 V2 版本的推理 API,SDK 也支持对 V2 版本的 API 进行调用,只需要创建对象时传入 `version="2"` 即可,其余使用方法与上述一致,差异点主要在于字段名称,具体字段名请参考 API 文档 +千帆平台推出了 V2 版本的推理 API,SDK 也支持对 V2 版本的 API 进行调用: + +##### V2 鉴权 + +API v2 采用Bearer Token的鉴权方式:可以通过access_key 和 secret_key 获取。因此可以选择以下两种方式设置鉴权信息: +```python +import os +# 安全认证 +os.environ['QIANFAN_ACCESS_KEY'] = 'your_access_key' +os.environ['QIANFAN_SECRET_KEY'] = 'your_secret_key' +# 或 bearer token +os.environ['QIANFAN_BEARER_TOKEN'] = 'your_bearer_token' +``` + 
+我们可以运行以下接口获取BEARER_TOKEN(可用于需要临时鉴权,或进行应用分发的场景): + +```python +import os +os.environ['QIANFAN_ACCESS_KEY'] = 'your_access_key' +os.environ['QIANFAN_SECRET_KEY'] = 'your_secret_key' + +resp = IAM.create_bearer_token(100) +print(resp.body) +token = resp.body["token"] +``` + +##### V2 示例: + +只需要创建对象时传入 `version="2"` 即可,其余使用方法与上述一致,差异点主要在于字段名称,具体字段名请参考 API 文档 ```python # 在创建时传入 version 以使用 V2 版本 # model 字段为可选,默认为 ernie-speed-8k,也可以指定其他模型,后续调用均会使用该模型 -chat = qianfan.ChatCompletion(version="2", model="ernie-speed-8k") +chat = qianfan.ChatCompletion(version="2", app_id='app-xxx', model="ernie-speed-8k") # 调用方式与 V1 版本一致,具体字段名参考 API 文档 resp = chat.do( diff --git a/go/README.md b/go/README.md index 904f4eef..4d1ee4a9 100644 --- a/go/README.md +++ b/go/README.md @@ -263,3 +263,21 @@ chat := qianfan.NewChatCompletion( // Completion 与 Embedding 可以用同样 WithLLMRetryBackoffFactor(1), // 指数回避因子 ) ``` + +同时,由于只有部分错误可以通过重试解决,SDK 只会对部分错误码进行重试,可以通过如下方式自定义修改 + +```go +qianfan.GetConfig().RetryErrCodes = []int{ + // 以下是 SDK 默认重试的错误码 + qianfan.ServiceUnavailableErrCode, // 2 + qianfan.ServerHighLoadErrCode, // 336100 + qianfan.QPSLimitReachedErrCode, // 18 + qianfan.RPMLimitReachedErrCode, // 336501 + qianfan.TPMLimitReachedErrCode, // 336502 + qianfan.AppNotExistErrCode, // 15 + // 以下为非内置错误码,仅为示例如何增加自定义错误码 + qianfan.UnknownErrorErrCode, + // 也可以直接提供 int 类型的错误码 + 336000, +} +``` diff --git a/go/qianfan/model_endpoint_retriever.go b/go/qianfan/model_endpoint_retriever.go index dd8b7aa8..5a7e896c 100644 --- a/go/qianfan/model_endpoint_retriever.go +++ b/go/qianfan/model_endpoint_retriever.go @@ -45,7 +45,7 @@ func getModelEndpointRetriever() *modelEndpointRetriever { "chat": ChatModelEndpoint, "completions": CompletionModelEndpoint, "embeddings": EmbeddingEndpoint, - "text2image": make(map[string]string), + "text2image": Text2ImageEndpoint, "image2text": make(map[string]string), } for modelType, endpointMap := range initMap { diff --git a/go/qianfan/text2img.go b/go/qianfan/text2img.go index 
c444e5a1..41128477 100644 --- a/go/qianfan/text2img.go +++ b/go/qianfan/text2img.go @@ -155,7 +155,7 @@ func (c *Text2Image) ModelList() []string { models := getModelEndpointRetriever().GetModelList(context.TODO(), "text2image") list := make([]string, len(models)) i := 0 - for k := range Text2ImageEndpoint { + for k := range models { list[i] = k i++ } diff --git a/go/qianfan/version.go b/go/qianfan/version.go index 180feeed..fd08a6b7 100644 --- a/go/qianfan/version.go +++ b/go/qianfan/version.go @@ -26,5 +26,5 @@ package qianfan // SDK 版本 -const Version = "v0.0.10" +const Version = "v0.0.11" const versionIndicator = "qianfan_go_sdk_" + Version diff --git a/java/README.md b/java/README.md index fa6de481..a33e936b 100644 --- a/java/README.md +++ b/java/README.md @@ -12,7 +12,7 @@ com.baidubce qianfan - 0.0.9 + 0.1.0 ``` @@ -21,13 +21,13 @@ 对于Kotlin DSL,在build.gradle.kts的dependencies中添加依赖。 ```kotlin -implementation("com.baidubce:qianfan:0.0.9") +implementation("com.baidubce:qianfan:0.1.0") ``` 对于Groovy DSL,在build.gradle的dependencies中添加依赖。 ```groovy -implementation 'com.baidubce:qianfan:0.0.9' +implementation 'com.baidubce:qianfan:0.1.0' ``` > 我们提供了一些 [示例](./examples),可以帮助快速了解 SDK 的使用方法并完成常见功能。 diff --git a/java/example/pom.xml b/java/example/pom.xml index 52509fba..1de49559 100644 --- a/java/example/pom.xml +++ b/java/example/pom.xml @@ -18,7 +18,7 @@ com.baidubce qianfan - 0.0.9 + 0.1.0 diff --git a/java/example/src/main/java/com/baidubce/ConsoleExample.java b/java/example/src/main/java/com/baidubce/ConsoleExample.java new file mode 100644 index 00000000..9b214b7b --- /dev/null +++ b/java/example/src/main/java/com/baidubce/ConsoleExample.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce; + +import com.baidubce.qianfan.Qianfan; +import com.baidubce.qianfan.util.CollUtils; + +import java.util.Map; + +/** + * 本示例实现了Console管控API调用流程 + * API文档可见 API列表 + */ +public class ConsoleExample { + public static void main(String[] args) { + describePresetServices(); + describeTPMResource(); + } + + private static void describePresetServices() { + // 获取预置服务列表 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Glygmrg7v + Map response = new Qianfan().console() + // 对应文档中请求地址的后缀 + .route("/v2/service") + // 对应文档中Query参数的Action + .action("DescribePresetServices") + // 如果不传入任何Response类,则默认返回Map + .execute() + // 可以传入class或者TypeRef来指定反序列化后返回的Response类 + // .execute(DescribePresetServicesResponse.class) + .getResult(); + System.out.println(response); + } + + private static void describeTPMResource() { + // 查询TPM配额信息详情 https://cloud.baidu.com/doc/WENXINWORKSHOP/s/ultmls9l9 + Map response = new Qianfan().console().route("/v2/charge").action("DescribeTPMResource") + // 需要传入参数的场景,可以自行封装请求类,或者使用Map.of()来构建请求Body + // Java 8可以使用SDK提供的CollUtils.mapOf()来替代Map.of() + .body(CollUtils.mapOf( + "model", "ernie-4.0-8k", + "paymentTiming", "Postpaid" + )) + .execute() + .getResult(); + System.out.println(response); + } +} + + diff --git a/java/example/src/main/java/com/baidubce/SystemMemoryExample.java b/java/example/src/main/java/com/baidubce/SystemMemoryExample.java new file mode 100644 index 00000000..33263aef --- /dev/null +++ b/java/example/src/main/java/com/baidubce/SystemMemoryExample.java @@ -0,0 +1,113 @@ +/* + * Copyright (c) 2024 
Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce; + +import com.baidubce.qianfan.Qianfan; +import com.baidubce.qianfan.core.auth.Auth; +import com.baidubce.qianfan.core.builder.MessageBuilder; +import com.baidubce.qianfan.model.chat.Message; +import com.baidubce.qianfan.util.CollUtils; + +import java.util.List; +import java.util.Map; + +/** + * 本示例实现了简易的系统记忆管理接口及推理接口的全流程调用 + * 系统记忆Console接口文档可见 创建系统记忆 + */ +public class SystemMemoryExample { + // 在模型服务-应用接入中创建应用,即可获得应用的AppID、API Key和Secret Key + private static final String APP_ID = "替换为实际的AppId"; + private static final String APP_API_KEY = "替换为实际的ApiKey"; + private static final String APP_SECRET_KEY = "替换为实际的SecretKey"; + + public static void main(String[] args) throws InterruptedException { + // 注意,在生产环境中,应当手动创建一个系统记忆并维护记忆内容,然后在推理中重复使用该系统记忆 + String systemMemoryId = createSystemMemory(APP_ID, "度小茶饮品店智能客服系统记忆"); + System.out.println("系统记忆ID:" + systemMemoryId); + + Boolean result = modifySystemMemory(systemMemoryId, CollUtils.listOf( + new MessageBuilder() + .add("user", "你的幸运数字是什么?") + .add("system", "我的幸运数字是42。") + .build(), + new MessageBuilder() + .add("user", "能推荐一款适合夏天饮用的饮品吗?") + .add("system", "当然可以,我们推荐冰镇柠檬绿茶,清新爽口,非常适合夏日消暑。") + .build() + )); + System.out.println("修改系统记忆结果:" + result); + + Thread.sleep(5000); + + Map memories = describeSystemMemory(systemMemoryId); + System.out.println("记忆列表:" + memories); + + String system = 
"你是度小茶饮品店的智能客服。"; + String response = chat(systemMemoryId, system, "你的幸运数字是什么"); + System.out.println("推理结果:" + response); + String response2 = chat(systemMemoryId, system, "推荐一个适合夏天的饮料"); + System.out.println("推理结果2:" + response2); + } + + private static String createSystemMemory(String appId, String description) { + return new Qianfan().console() + .route("/v2/memory") + .action("CreateSystemMemory") + .body(CollUtils.mapOf( + "appId", appId, + "description", description + )) + .execute(String.class) + .getResult(); + } + + private static Boolean modifySystemMemory(String systemMemoryId, List> memories) { + return new Qianfan().console() + .route("/v2/memory") + .action("ModifySystemMemory") + .body(CollUtils.mapOf( + "systemMemoryId", systemMemoryId, + "memories", memories + )) + .execute(Boolean.class) + .getResult(); + } + + private static Map describeSystemMemory(String systemMemoryId) { + return new Qianfan().console() + .route("/v2/memory") + .action("DescribeSystemMemory") + .body(CollUtils.mapOf( + "systemMemoryId", systemMemoryId + )) + .execute() + .getResult(); + } + + private static String chat(String systemMemoryId, String system, String query) { + // 使用系统记忆时,鉴权需要使用OAuth方式,同时需要传入与系统记忆相同应用的Api Key和Secret Key + return new Qianfan(Auth.TYPE_OAUTH, APP_API_KEY, APP_SECRET_KEY).chatCompletion() + .model("ERNIE-3.5-8K") + .system(system) + .enableSystemMemory(true) + .systemMemoryId(systemMemoryId) + .addUserMessage(query) + .execute() + .getResult(); + } +} diff --git a/java/pom.xml b/java/pom.xml index d1b6e285..265987d5 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -4,7 +4,7 @@ com.baidubce qianfan - 0.0.9 + 0.1.0 jar qianfan diff --git a/java/src/main/java/com/baidubce/qianfan/Qianfan.java b/java/src/main/java/com/baidubce/qianfan/Qianfan.java index 3f718124..04a33848 100644 --- a/java/src/main/java/com/baidubce/qianfan/Qianfan.java +++ b/java/src/main/java/com/baidubce/qianfan/Qianfan.java @@ -26,6 +26,8 @@ import 
com.baidubce.qianfan.model.chat.ChatResponse; import com.baidubce.qianfan.model.completion.CompletionRequest; import com.baidubce.qianfan.model.completion.CompletionResponse; +import com.baidubce.qianfan.model.console.ConsoleRequest; +import com.baidubce.qianfan.model.console.ConsoleResponse; import com.baidubce.qianfan.model.embedding.EmbeddingRequest; import com.baidubce.qianfan.model.embedding.EmbeddingResponse; import com.baidubce.qianfan.model.image.Image2TextRequest; @@ -37,6 +39,9 @@ import com.baidubce.qianfan.model.rerank.RerankRequest; import com.baidubce.qianfan.model.rerank.RerankResponse; +import java.lang.reflect.Type; + + public class Qianfan { private final QianfanClient client; @@ -138,6 +143,14 @@ public StreamIterator pluginStream(PluginRequest request) { return requestStream(request, PluginResponse.class); } + public ConsoleBuilder console() { + return new ConsoleBuilder(this); + } + + public ConsoleResponse console(ConsoleRequest request, Type type) { + return consoleRequest(request, type); + } + public , U extends BaseRequest> T request(BaseRequest request, Class responseClass) { return client.request(request, responseClass); } @@ -145,4 +158,8 @@ public , U extends BaseRequest> T request(BaseReque public , U extends BaseRequest> StreamIterator requestStream(BaseRequest request, Class responseClass) { return client.requestStream(request, responseClass); } + + public ConsoleResponse consoleRequest(ConsoleRequest request, Type type) { + return client.consoleRequest(request, type); + } } \ No newline at end of file diff --git a/java/src/main/java/com/baidubce/qianfan/QianfanClient.java b/java/src/main/java/com/baidubce/qianfan/QianfanClient.java index c3d938a9..29c48581 100644 --- a/java/src/main/java/com/baidubce/qianfan/QianfanClient.java +++ b/java/src/main/java/com/baidubce/qianfan/QianfanClient.java @@ -21,18 +21,27 @@ import com.baidubce.qianfan.core.RateLimiter; import com.baidubce.qianfan.core.StreamIterator; import 
com.baidubce.qianfan.core.auth.Auth; +import com.baidubce.qianfan.core.auth.IAMAuth; import com.baidubce.qianfan.core.auth.IAuth; import com.baidubce.qianfan.model.*; +import com.baidubce.qianfan.model.console.ConsoleRequest; +import com.baidubce.qianfan.model.console.ConsoleResponse; import com.baidubce.qianfan.model.exception.ApiException; +import com.baidubce.qianfan.model.exception.AuthException; import com.baidubce.qianfan.model.exception.QianfanException; import com.baidubce.qianfan.model.exception.RequestException; import com.baidubce.qianfan.util.Json; +import com.baidubce.qianfan.util.ParameterizedTypeImpl; import com.baidubce.qianfan.util.StringUtils; import com.baidubce.qianfan.util.function.ThrowingFunction; import com.baidubce.qianfan.util.http.*; +import java.lang.reflect.Type; + class QianfanClient { - private static final String SDK_VERSION = "0.0.9"; + private static final String SDK_VERSION = "0.1.0"; + private static final String CONSOLE_URL_NO_ACTION_TEMPLATE = "%s%s"; + private static final String CONSOLE_URL_ACTION_TEMPLATE = "%s%s?Action=%s"; private static final String QIANFAN_URL_TEMPLATE = "%s/rpc/2.0/ai_custom/v1/wenxinworkshop%s"; private static final String EXTRA_PARAM_REQUEST_SOURCE = "request_source"; private static final String REQUEST_SOURCE_PREFIX = "qianfan_java_sdk_v"; @@ -120,6 +129,36 @@ private , U, V, E extends Exception> V request( throw new IllegalStateException("Request failed with unknown error"); } + public ConsoleResponse consoleRequest(ConsoleRequest request, Type type) { + try { + if (!(auth instanceof IAMAuth)) { + throw new AuthException("Console request requires IAM authentication"); + } + String url = StringUtils.isNotEmpty(request.getAction()) + ? 
String.format(CONSOLE_URL_ACTION_TEMPLATE, QianfanConfig.getConsoleApiBaseUrl(), request.getRoute(), request.getAction()) + : String.format(CONSOLE_URL_NO_ACTION_TEMPLATE, QianfanConfig.getConsoleApiBaseUrl(), request.getRoute()); + HttpRequest httpRequest = HttpClient.request() + .post(url) + .body(request.getBody() == null ? new Object() : request.getBody()); + + Type respType = new ParameterizedTypeImpl(ConsoleResponse.class, new Type[]{type}); + HttpResponse> resp = auth.signRequest(httpRequest).executeJson(respType); + + if (resp.getCode() != HttpStatus.SUCCESS) { + throw new RequestException(String.format("Request failed with status code %d: %s", resp.getCode(), resp.getStringBody())); + } + ApiErrorResponse errorResp = Json.deserialize(resp.getStringBody(), ApiErrorResponse.class); + if (StringUtils.isNotEmpty(errorResp.getErrorMsg())) { + throw new ApiException("Request failed with api error", errorResp); + } + return resp.getBody(); + } catch (QianfanException e) { + throw e; + } catch (Exception e) { + throw new RequestException(String.format("Request failed: %s", e.getMessage()), e); + } + } + private , U, V, E extends Exception> V innerRequest( BaseRequest request, ThrowingFunction, E> reqProcessor, diff --git a/java/src/main/java/com/baidubce/qianfan/core/StreamIterator.java b/java/src/main/java/com/baidubce/qianfan/core/StreamIterator.java index 031302de..66647422 100644 --- a/java/src/main/java/com/baidubce/qianfan/core/StreamIterator.java +++ b/java/src/main/java/com/baidubce/qianfan/core/StreamIterator.java @@ -23,7 +23,6 @@ import com.baidubce.qianfan.util.http.SSEIterator; import java.io.Closeable; -import java.io.IOException; import java.util.Iterator; import java.util.Map; import java.util.Objects; diff --git a/java/src/main/java/com/baidubce/qianfan/core/builder/ConsoleBuilder.java b/java/src/main/java/com/baidubce/qianfan/core/builder/ConsoleBuilder.java new file mode 100644 index 00000000..c419d184 --- /dev/null +++ 
b/java/src/main/java/com/baidubce/qianfan/core/builder/ConsoleBuilder.java @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce.qianfan.core.builder; + +import com.baidubce.qianfan.Qianfan; +import com.baidubce.qianfan.model.console.ConsoleRequest; +import com.baidubce.qianfan.model.console.ConsoleResponse; +import com.baidubce.qianfan.model.exception.ValidationException; +import com.baidubce.qianfan.util.TypeRef; + +import java.lang.reflect.Type; +import java.util.Map; + +public class ConsoleBuilder { + private Qianfan qianfan; + + private String route; + + private String action; + + private Object body; + + public ConsoleBuilder() { + super(); + } + + public ConsoleBuilder(Qianfan qianfan) { + this.qianfan = qianfan; + } + + public ConsoleBuilder route(String route) { + this.route = route; + return this; + } + + public ConsoleBuilder action(String action) { + this.action = action; + return this; + } + + public ConsoleBuilder body(Object body) { + this.body = body; + return this; + } + + public ConsoleRequest build() { + return new ConsoleRequest() + .setRoute(route) + .setAction(action) + .setBody(body); + } + + public ConsoleResponse> execute() { + return executeWithCheck(new TypeRef>() {}.getType()); + } + + public ConsoleResponse execute(TypeRef typeRef) { + return executeWithCheck(typeRef.getType()); + } + + public ConsoleResponse execute(Class clazz) 
{ + return executeWithCheck(clazz); + } + + public ConsoleResponse execute(Type type) { + return executeWithCheck(type); + } + + private ConsoleResponse executeWithCheck(Type type) { + if (qianfan == null) { + throw new ValidationException("Qianfan client is not set. " + + "please create builder from Qianfan client, " + + "or use build() instead of execute() to get Request and send it by yourself."); + } + return qianfan.consoleRequest(build(), type); + } +} diff --git a/java/src/main/java/com/baidubce/qianfan/model/console/ConsoleRequest.java b/java/src/main/java/com/baidubce/qianfan/model/console/ConsoleRequest.java new file mode 100644 index 00000000..d6e222b1 --- /dev/null +++ b/java/src/main/java/com/baidubce/qianfan/model/console/ConsoleRequest.java @@ -0,0 +1,62 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.baidubce.qianfan.model.console; + +public class ConsoleRequest { + /** + * 请求路由,例如/v2/service + * route、action可参考文档: API列表 + */ + private String route; + + /** + * 请求操作,例如DescribePresetServices + */ + private String action; + + /** + * 请求发送的POST数据 + */ + private Object body; + + public String getRoute() { + return route; + } + + public ConsoleRequest setRoute(String route) { + this.route = route; + return this; + } + + public String getAction() { + return action; + } + + public ConsoleRequest setAction(String action) { + this.action = action; + return this; + } + + public Object getBody() { + return body; + } + + public ConsoleRequest setBody(Object body) { + this.body = body; + return this; + } +} diff --git a/java/src/main/java/com/baidubce/qianfan/model/console/ConsoleResponse.java b/java/src/main/java/com/baidubce/qianfan/model/console/ConsoleResponse.java index 38085813..ec60aaeb 100644 --- a/java/src/main/java/com/baidubce/qianfan/model/console/ConsoleResponse.java +++ b/java/src/main/java/com/baidubce/qianfan/model/console/ConsoleResponse.java @@ -22,20 +22,20 @@ public class ConsoleResponse { /** * 请求ID */ - @JsonProp("logId") - private String logId; + @JsonProp("requestId") + private String requestId; /** * 请求结果 */ private T result; - public String getLogId() { - return logId; + public String getRequestId() { + return requestId; } - public ConsoleResponse setLogId(String logId) { - this.logId = logId; + public ConsoleResponse setRequestId(String requestId) { + this.requestId = requestId; return this; } diff --git a/java/src/main/java/com/baidubce/qianfan/util/CollUtils.java b/java/src/main/java/com/baidubce/qianfan/util/CollUtils.java new file mode 100644 index 00000000..85a9c3f8 --- /dev/null +++ b/java/src/main/java/com/baidubce/qianfan/util/CollUtils.java @@ -0,0 +1,63 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.baidubce.qianfan.util; + +import java.util.*; + +public class CollUtils { + + private CollUtils() { + } + + public static Map mapOf(Object... keyValues) { + if (keyValues.length % 2 != 0) { + throw new IllegalArgumentException("Invalid key-value pairs"); + } + + Map map = new LinkedHashMap<>(); + for (int i = 0; i < keyValues.length; i += 2) { + @SuppressWarnings("unchecked") + K key = (K) keyValues[i]; + @SuppressWarnings("unchecked") + V value = (V) keyValues[i + 1]; + if (key == null) { + throw new IllegalArgumentException("Key at index " + i + " is null"); + } + map.put(key, value); + } + return map; + } + + @SafeVarargs + public static List listOf(T... values) { + List list = new ArrayList<>(); + Collections.addAll(list, values); + return list; + } + + @SafeVarargs + public static T[] arrayOf(T... values) { + return values; + } + + @SafeVarargs + public static Set setOf(T... 
values) { + Set set = new LinkedHashSet<>(); + Collections.addAll(set, values); + return set; + } +} diff --git a/java/src/main/java/com/baidubce/qianfan/util/Json.java b/java/src/main/java/com/baidubce/qianfan/util/Json.java index 54791495..3332433d 100644 --- a/java/src/main/java/com/baidubce/qianfan/util/Json.java +++ b/java/src/main/java/com/baidubce/qianfan/util/Json.java @@ -40,6 +40,14 @@ public static String serialize(Object object) { return GSON.toJson(object); } + public static T deserialize(String json, TypeRef typeRef) { + return GSON.fromJson(json, typeRef.getType()); + } + + public static T deserialize(String json, Class clazz) { + return GSON.fromJson(json, clazz); + } + public static T deserialize(String json, Type type) { return GSON.fromJson(json, type); } diff --git a/java/src/main/java/com/baidubce/qianfan/util/ParameterizedTypeImpl.java b/java/src/main/java/com/baidubce/qianfan/util/ParameterizedTypeImpl.java new file mode 100644 index 00000000..04fc9464 --- /dev/null +++ b/java/src/main/java/com/baidubce/qianfan/util/ParameterizedTypeImpl.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Baidu, Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package com.baidubce.qianfan.util; + +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; + +public class ParameterizedTypeImpl implements ParameterizedType { + private final Class raw; + private final Type[] args; + + public ParameterizedTypeImpl(Class raw, Type[] args) { + this.raw = raw; + this.args = args; + } + + @Override + public Type[] getActualTypeArguments() { + return args; + } + + @Override + public Type getRawType() { + return raw; + } + + @Override + public Type getOwnerType() { + return null; + } +} diff --git a/python/pyproject.toml b/python/pyproject.toml index a8193db9..35441c4f 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "qianfan" -version = "0.4.5" +version = "0.4.6" description = "文心千帆大模型平台 Python SDK" authors = [] license = "Apache-2.0" diff --git a/python/qianfan/__init__.py b/python/qianfan/__init__.py index ffa22e4f..3e46622a 100644 --- a/python/qianfan/__init__.py +++ b/python/qianfan/__init__.py @@ -70,7 +70,7 @@ "QfRole", "QfMessages", "QfResponse", - "AccessToken", + "Token", "AccessKey", "SecretKey", "get_config", diff --git a/python/qianfan/config.py b/python/qianfan/config.py index b8f36456..3ebe7038 100644 --- a/python/qianfan/config.py +++ b/python/qianfan/config.py @@ -35,17 +35,23 @@ class Config: ACCESS_KEY: Optional[str] = Field(default=None) SECRET_KEY: Optional[str] = Field(default=None) ACCESS_TOKEN: Optional[str] = Field(default=None) + BEARER_TOKEN: Optional[str] = Field(default=None) BASE_URL: str = Field(default=DefaultValue.BaseURL) NO_AUTH: bool = Field(default=False) + USE_CUSTOM_ENDPOINT: bool = Field(default=False) MODEL_API_PREFIX: str = Field(default=DefaultValue.ModelAPIPrefix) AUTH_TIMEOUT: float = Field(default=DefaultValue.AuthTimeout) DISABLE_EB_SDK: bool = Field(default=DefaultValue.DisableErnieBotSDK) EB_SDK_INSTALLED: bool = Field(default=False) IAM_SIGN_EXPIRATION_SEC: int = 
Field(default=DefaultValue.IAMSignExpirationSeconds) CONSOLE_API_BASE_URL: str = Field(default=DefaultValue.ConsoleAPIBaseURL) + IAM_BASE_URL: str = Field(default=DefaultValue.IAMBaseURL) ACCESS_TOKEN_REFRESH_MIN_INTERVAL: float = Field( default=DefaultValue.AccessTokenRefreshMinInterval ) + BEARER_TOKEN_EXPIRED_INTERVAL: int = Field( + default=DefaultValue.BearerTokenExpiredInterval + ) INFER_RESOURCE_REFRESH_INTERVAL: float = Field( default=DefaultValue.InferResourceRefreshMinInterval ) diff --git a/python/qianfan/consts.py b/python/qianfan/consts.py index 3d171ffb..b1f30b16 100644 --- a/python/qianfan/consts.py +++ b/python/qianfan/consts.py @@ -74,6 +74,7 @@ class Env: AuthTimeout: str = "QIANFAN_AUTH_TIMEOUT" IAMSignExpirationSeconds: str = "QIANFAN_IAM_SIGN_EXPIRATION_SEC" ConsoleAPIBaseURL: str = "QIANFAN_CONSOLE_API_BASE_URL" + IAMBaseURL: str = "QIANFAN_IAM_BASE_URL" AccessTokenRefreshMinInterval: str = "QIANFAN_ACCESS_TOKEN_REFRESH_MIN_INTERVAL" InferResourceRefreshMinInterval: str = "QIANFAN_INFER_RESOURCE_REFRESH_MIN_INTERVAL" EnablePrivate: str = "QIANFAN_ENABLE_PRIVATE" @@ -204,6 +205,19 @@ class Consts: DatasetCreateOfflineBatchInferenceAction: str = "CreateBatchInferenceTask" DatasetDescribeOfflineBatchInferenceAction: str = "DescribeBatchInferenceTask" DatasetDescribeOfflineBatchInferencesAction: str = "DescribeBatchInferenceTasks" + DatasetV2BaseRouteAPI: str = "/v2/dataset" + DatasetV2CreateDatasetAction: str = "CreateDataset" + DatasetV2GetDatasetListAction: str = "DescribeDatasets" + DatasetV2DeleteDatasetAction: str = "DeleteDataset" + DatasetV2CreateDatasetVersionAction: str = "CreateDatasetVersion" + DatasetV2GetDatasetVersionInfoAction: str = "DescribeDatasetVersion" + DatasetV2DeleteDatasetVersionAction: str = "DeleteDatasetVersion" + DatasetV2PublishDatasetVersionAction: str = "PublishDatasetVersion" + DatasetV2GetDatasetVersionListAction: str = "DescribeDatasetVersions" + DatasetV2CreateDatasetVersionImportTaskAction: str = 
"CreateImportTask" + DatasetV2GetDatasetVersionImportTaskInfoAction: str = "DescribeImportTask" + DatasetV2CreateDatasetVersionExportTaskAction: str = "CreateExportTask" + DatasetV2GetDatasetVersionExportTaskInfoAction: str = "DescribeExportTask" PromptRenderAPI: str = "/rest/2.0/wenxinworkshop/api/v1/template/info" PromptCreateAPI: str = "/wenxinworkshop/prompt/template/create" PromptInfoAPI: str = "/wenxinworkshop/prompt/template/info" @@ -239,7 +253,8 @@ class Consts: PrivateResourceGetResourceParam: str = "DescribeServiceResource" PrivateResourceReleaseServiceResourceParam: str = "ReleaseServiceResource" - ChatV2API: str = "/v2/chat" + ChatV2API: str = "/v2/chat/completions" + IAMBearerTokenAPI: str = "/v1/BCE-BEARER/token" STREAM_RESPONSE_PREFIX: str = "data: " STREAM_RESPONSE_EVENT_PREFIX: str = "event: " @@ -270,7 +285,9 @@ class DefaultValue: DisableErnieBotSDK: bool = True IAMSignExpirationSeconds: int = 300 ConsoleAPIBaseURL: str = "https://qianfan.baidubce.com" + IAMBaseURL: str = "https://iam.bj.baidubce.com" AccessTokenRefreshMinInterval: float = 3600 + BearerTokenExpiredInterval: int = 43200 InferResourceRefreshMinInterval: float = 600 RetryCount: int = 3 RetryTimeout: float = 300 diff --git a/python/qianfan/dataset/__init__.py b/python/qianfan/dataset/__init__.py index da037968..26b36aa3 100644 --- a/python/qianfan/dataset/__init__.py +++ b/python/qianfan/dataset/__init__.py @@ -17,6 +17,7 @@ """ from qianfan.dataset.data_source import ( + AFSDataSource, BosDataSource, DataSource, FileDataSource, @@ -48,4 +49,5 @@ "QianfanDataSource", "BosDataSource", "FileDataSource", + "AFSDataSource", ] diff --git a/python/qianfan/dataset/consts.py b/python/qianfan/dataset/consts.py index 55a6bf55..66742cc1 100644 --- a/python/qianfan/dataset/consts.py +++ b/python/qianfan/dataset/consts.py @@ -44,6 +44,15 @@ # 用于保存文生图数据集的压缩包的缓存的目录 QianfanDatasetText2ImageUnzipCacheDir = QianfanDatasetLocalCacheDir / ".unzip_text2img" +# AFS 数据源使用的缓存目录 +QianfanDatasetAFSCacheDir = 
QianfanDatasetLocalCacheDir / ".afs_cache" + +# AFS 数据源使用的上传目录 +QianfanDatasetAFSUploadingCacheDir = QianfanDatasetAFSCacheDir / "uploading" + +# AFS 数据源使用的下载目录 +QianfanDatasetAFSDownloadingCacheDir = QianfanDatasetAFSCacheDir / "downloading" + # 本地缓存中,元数据的后缀名 QianfanDatasetMetaInfoExtensionName = ".meta" diff --git a/python/qianfan/dataset/data_source/__init__.py b/python/qianfan/dataset/data_source/__init__.py index afe378e2..e1736b1f 100644 --- a/python/qianfan/dataset/data_source/__init__.py +++ b/python/qianfan/dataset/data_source/__init__.py @@ -15,6 +15,7 @@ data source including file """ +from qianfan.dataset.data_source.afs import AFSDataSource from qianfan.dataset.data_source.baidu_qianfan import QianfanDataSource from qianfan.dataset.data_source.base import DataSource, FormatType from qianfan.dataset.data_source.bos import BosDataSource @@ -26,4 +27,5 @@ "QianfanDataSource", "BosDataSource", "FormatType", + "AFSDataSource", ] diff --git a/python/qianfan/dataset/data_source/afs.py b/python/qianfan/dataset/data_source/afs.py new file mode 100644 index 00000000..71f63d94 --- /dev/null +++ b/python/qianfan/dataset/data_source/afs.py @@ -0,0 +1,272 @@ +import hashlib +import json +import os +from typing import Any, Optional + +import dateutil +import pyarrow + +from qianfan.config import encoding +from qianfan.dataset.consts import ( + QianfanDatasetAFSDownloadingCacheDir, + QianfanDatasetAFSUploadingCacheDir, + _merge_custom_path, +) +from qianfan.dataset.data_source.base import DataSource, FormatType +from qianfan.dataset.data_source.utils import ( + _get_a_pyarrow_table, + _pack_a_table_into_file_for_uploading, + _read_all_file_from_zip, + _read_all_image_from_zip, +) +from qianfan.dataset.table import Table +from qianfan.resources.batch_inference.helper.helper import AFSClient +from qianfan.utils import log_error, log_info, log_warn +from qianfan.utils.pydantic import BaseModel, Field + +try: + import dateutil.parser +except ImportError: + 
log_warn("python-dateutil isn't installed, only online function can be used") + + +class AFSDataSource(DataSource, BaseModel): + host: str + ugi: str + afs_file_path: str + file_format: Optional[FormatType] = Field(default=None) + + def save( + self, + table: Table, + should_save_as_zip_file: bool = False, + should_overwrite_existed_file: bool = False, + should_use_qianfan_special_jsonl_format: bool = False, + **kwargs: Any, + ) -> bool: + # 特判一下,防止用户手滑设置了出现意外情况 + if ( + should_use_qianfan_special_jsonl_format + and self.format_type() != FormatType.Jsonl + ): + should_use_qianfan_special_jsonl_format = False + + # 如果是文生图,则需要强制上压缩包 + if self.format_type() == FormatType.Text2Image: + should_save_as_zip_file = True + + assert self.file_format + + afs_client = AFSClient(host=self.host, ugi=self.ugi) + + # 构建远端的 AFS 路径 + if not should_save_as_zip_file: + final_afs_file_path = self.afs_file_path + + else: + final_afs_file_path = self.afs_file_path.replace( + f".{self.file_format.value}", ".zip" + ) + + log_info( + f"start to upload file to afs path: {final_afs_file_path} in host" + f" {self.host}" + ) + + # 检查 AFS 上是否已经存在文件 + if not should_overwrite_existed_file: + file_existed = afs_client.test(self.afs_file_path, "-e") + + if file_existed == 0: + err_msg = ( + f"{final_afs_file_path} existed and argument" + " 'should_overwrite_existed_file' is False" + ) + log_error(err_msg) + raise ValueError(err_msg) + + # 如果设置了 should_overwrite_existed_file 则防御性删除文件 + if should_overwrite_existed_file: + log_info( + f"try to delete original afs file {final_afs_file_path} for overwrite" + ) + try: + afs_client.rm(self.afs_file_path) + except ValueError: + # do nothing + ... 
+ + # 构造本地的缓存路径 + local_file_path = os.path.join( + self._get_specific_uploading_cache_path(), + os.path.split(final_afs_file_path)[1], + ) + + local_file_path = _pack_a_table_into_file_for_uploading( + table, + local_file_path, + self.file_format, + should_save_as_zip_file, + should_use_qianfan_special_jsonl_format, + ) + + try: + log_info( + f"start to upload file {local_file_path} to afs {final_afs_file_path}" + ) + afs_client.put(local_file_path, final_afs_file_path) + except Exception as e: + err_msg = ( + "an error occurred during upload data to afs with path" + f" {final_afs_file_path} of host {self.host}: {str(e)}" + ) + log_error(err_msg) + raise e + + return True + + def fetch(self, read_from_zip: bool = False, **kwargs: Any) -> pyarrow.Table: + assert self.file_format + + afs_client = AFSClient(host=self.host, ugi=self.ugi) + if not self._check_afs_file_cache(afs_client): + log_info("cache was outdated, start to update afs cache") + self._update_file_cache(afs_client) + + # 检查是否是从一个压缩包读取文件 + index = self.afs_file_path.rfind(".") + read_from_zip = read_from_zip or ( + index != -1 and self.afs_file_path[index + 1 :] == "zip" + ) + + log_info( + f"ready to fetch a file from afs path: {self.afs_file_path} in host" + f" {self.host}" + ) + + try: + return self._read_from_cache(read_from_zip, **kwargs) + except Exception as e: + err_msg = ( + f"fetch file content from afs path {self.afs_file_path} of host" + f" {self.host} failed: {str(e)}" + ) + log_error(err_msg) + raise e + + def load(self, **kwargs: Any) -> Optional[pyarrow.Table]: + """ + Get a pyarrow.Table from current DataSource object + + Args: + **kwargs (Any): Arbitrary keyword arguments. 
+ + Returns: + Optional[pyarrow.Table]: A memory-mapped pyarrow.Table object or None + """ + return None + + def format_type(self) -> FormatType: + """ + Get format type binding to source + + Returns: + FormatType: format type binding to source + """ + assert self.file_format + return self.file_format + + def set_format_type(self, format_type: FormatType) -> None: + """ + Set format type binding to source + + Args: + format_type (FormatType): format type binding to source + """ + self.file_format = format_type + + def _get_specific_downloading_cache_path(self) -> str: + cache_path = os.path.join( + _merge_custom_path(QianfanDatasetAFSDownloadingCacheDir), + hashlib.md5( + bytes(self.host + self.afs_file_path, encoding="utf8") + ).hexdigest(), + ) + os.makedirs(cache_path, exist_ok=True) + + return cache_path + + def _get_specific_uploading_cache_path(self) -> str: + cache_path = os.path.join( + _merge_custom_path(QianfanDatasetAFSUploadingCacheDir), + hashlib.md5( + bytes(self.host + self.afs_file_path, encoding="utf8") + ).hexdigest(), + ) + os.makedirs(cache_path, exist_ok=True) + + return cache_path + + def _get_downloaded_content_cache_path(self) -> str: + cache_path = self._get_specific_downloading_cache_path() + return os.path.join(cache_path, "content") + + def _get_downloaded_content_metainfo_path(self) -> str: + cache_path = self._get_specific_downloading_cache_path() + return os.path.join(cache_path, "info.json") + + def _check_afs_file_cache(self, afs_client: AFSClient) -> bool: + meta_info_path = self._get_downloaded_content_metainfo_path() + if not os.path.exists(meta_info_path): + return False + + with open(meta_info_path, mode="r", encoding=encoding()) as f: + cache_meta_info = json.loads(f.read()) + + if "last_modified" not in cache_meta_info: + return False + + parser = dateutil.parser.parser() + cache_last_modified_time = parser.parse(cache_meta_info["last_modified"]) + new_last_modified_time = parser.parse( + 
afs_client.get_modify_time(self.afs_file_path) + ) + + return cache_last_modified_time >= new_last_modified_time + + def _update_file_cache(self, afs_client: AFSClient) -> None: + cache_content_path = self._get_downloaded_content_cache_path() + cache_meta_info_path = self._get_downloaded_content_metainfo_path() + + afs_client.get(self.afs_file_path, cache_content_path) + parser = dateutil.parser.parser() + new_last_modified_time = parser.parse( + afs_client.get_modify_time(self.afs_file_path) + ) + + with open(cache_meta_info_path, mode="w", encoding=encoding()) as f: + f.write( + json.dumps( + { + "last_modified": new_last_modified_time.strftime( + "%Y-%m-%d %H:%M:%S" + ) + }, + ensure_ascii=False, + ) + ) + + return + + def _read_from_cache(self, is_read_from_zip: bool, **kwargs: Any) -> pyarrow.Table: + cache_content_path = self._get_downloaded_content_cache_path() + + if self.format_type() == FormatType.Text2Image: + return _read_all_image_from_zip(cache_content_path) + + if is_read_from_zip: + return _read_all_file_from_zip( + cache_content_path, self.format_type(), **kwargs + ) + + return _get_a_pyarrow_table(cache_content_path, self.format_type(), **kwargs) diff --git a/python/qianfan/dataset/data_source/bos.py b/python/qianfan/dataset/data_source/bos.py index 1a0cffb7..714ec8ee 100644 --- a/python/qianfan/dataset/data_source/bos.py +++ b/python/qianfan/dataset/data_source/bos.py @@ -17,7 +17,6 @@ import json import os -import shutil from typing import Any, Dict, Optional import pyarrow @@ -30,13 +29,11 @@ _merge_custom_path, ) from qianfan.dataset.data_source.base import DataSource, FormatType -from qianfan.dataset.data_source.file import FileDataSource from qianfan.dataset.data_source.utils import ( - _collect_all_images_and_annotations_in_one_folder, _get_a_pyarrow_table, + _pack_a_table_into_file_for_uploading, _read_all_file_from_zip, _read_all_image_from_zip, - zip_file_or_folder, ) from qianfan.dataset.table import Table from qianfan.utils import 
log_error, log_info, log_warn @@ -151,41 +148,13 @@ def save( os.path.split(final_bos_file_path)[1], ) - from qianfan.dataset.dataset import Dataset - - if not ( - isinstance(table, Dataset) - and table.inner_table is None - and isinstance(table.inner_data_source_cache, FileDataSource) - ): - # 在特定情况下修改格式 - if table.is_dataset_grouped() and should_use_qianfan_special_jsonl_format: - table.pack() - - if self.format_type() != FormatType.Text2Image: - FileDataSource( - path=local_file_path, - file_format=self.format_type(), - save_as_folder=should_save_as_zip_file, - ).save( - table, - use_qianfan_special_jsonl_format=should_use_qianfan_special_jsonl_format, - **kwargs, - ) - else: - # 不同于千帆数据源会随机生成一个 UUID 拼接在文件名中 - # 这里需要手动删除上一次的中转文件夹 - # 避免重名带来的影响 - shutil.rmtree(local_file_path, ignore_errors=True) - _collect_all_images_and_annotations_in_one_folder( - table.inner_table, local_file_path - ) - else: - local_file_path = table.inner_data_source_cache.path - - # 打压缩包 - if should_save_as_zip_file: - local_file_path = zip_file_or_folder(local_file_path) + local_file_path = _pack_a_table_into_file_for_uploading( + table, + local_file_path, + self.file_format, + should_save_as_zip_file, + should_use_qianfan_special_jsonl_format, + ) try: log_info( diff --git a/python/qianfan/dataset/data_source/utils.py b/python/qianfan/dataset/data_source/utils.py index d1a188da..09a95636 100644 --- a/python/qianfan/dataset/data_source/utils.py +++ b/python/qianfan/dataset/data_source/utils.py @@ -806,3 +806,51 @@ def _create_release_data_task_and_wait_for_success( else: log_info("data releasing succeeded") return True + + +def _pack_a_table_into_file_for_uploading( + table: Any, + local_file_path: str, + format_type: FormatType, + should_save_as_zip_file: bool, + should_use_qianfan_special_jsonl_format: bool, + **kwargs: Any, +) -> str: + from qianfan.dataset.data_source import FileDataSource + from qianfan.dataset.dataset import Dataset + + if not ( + isinstance(table, Dataset) + and 
table.inner_table is None + and isinstance(table.inner_data_source_cache, FileDataSource) + ): + # 在特定情况下修改格式 + if table.is_dataset_grouped() and should_use_qianfan_special_jsonl_format: + table.pack() + + if format_type != FormatType.Text2Image: + FileDataSource( + path=local_file_path, + file_format=format_type, + save_as_folder=should_save_as_zip_file, + ).save( + table, + use_qianfan_special_jsonl_format=should_use_qianfan_special_jsonl_format, + **kwargs, + ) + else: + # 不同于千帆数据源会随机生成一个 UUID 拼接在文件名中 + # 这里需要手动删除上一次的中转文件夹 + # 避免重名带来的影响 + shutil.rmtree(local_file_path, ignore_errors=True) + _collect_all_images_and_annotations_in_one_folder( + table.inner_table, local_file_path + ) + else: + local_file_path = table.inner_data_source_cache.path + + # 打压缩包 + if should_save_as_zip_file: + local_file_path = zip_file_or_folder(local_file_path) + + return local_file_path diff --git a/python/qianfan/dataset/dataset.py b/python/qianfan/dataset/dataset.py index 6a7699ea..d5895f09 100644 --- a/python/qianfan/dataset/dataset.py +++ b/python/qianfan/dataset/dataset.py @@ -52,6 +52,7 @@ FileDataSource, QianfanDataSource, ) +from qianfan.dataset.data_source.afs import AFSDataSource from qianfan.dataset.data_source.utils import upload_data_from_bos_to_qianfan from qianfan.dataset.dataset_utils import ( _async_batch_do_on_service, @@ -269,6 +270,7 @@ def _from_args_to_source( qianfan_dataset_create_args: Optional[Dict[str, Any]] = None, bos_load_args: Optional[Dict[str, Any]] = None, bos_source_args: Optional[Dict[str, Any]] = None, + afs_source_args: Optional[Dict[str, Any]] = None, **kwargs: Any, ) -> Optional[DataSource]: """从参数来构建数据源""" @@ -278,6 +280,7 @@ def _from_args_to_source( f" {kwargs}" ) return FileDataSource(path=data_file, **kwargs) + if qianfan_dataset_id: log_info( "construct a qianfan data source from existed id:" @@ -286,6 +289,7 @@ def _from_args_to_source( return QianfanDataSource.get_existed_dataset( dataset_id=qianfan_dataset_id, **kwargs ) + if 
qianfan_dataset_create_args: log_info( "construct a new qianfan data source from args:" @@ -311,6 +315,19 @@ def _from_args_to_source( log_error(err_msg) raise ValueError(err_msg) + return bos_ds + + if afs_source_args: + afs_ds = AFSDataSource(**afs_source_args) + if afs_ds.file_format is None: + err_msg = ( + f"failed to create afs dataset file path {afs_ds.afs_file_path}" + ) + log_error(err_msg) + raise ValueError(err_msg) + + return afs_ds + log_info("no datasource was constructed") return None @@ -341,6 +358,7 @@ def load( qianfan_dataset_id: Optional[str] = None, bos_load_args: Optional[Dict[str, Any]] = None, bos_source_args: Optional[Dict[str, Any]] = None, + afs_source_args: Optional[Dict[str, Any]] = None, huggingface_dataset: Optional[Any] = None, dataframe: Optional[Any] = None, schema: Optional[Schema] = None, @@ -367,6 +385,9 @@ def load( bos_source_args: (Optional[Dict[str, Any]]): create arguments for creating a file on specific bos default to None + afs_source_args: (Optional[Dict[str, Any]]): + create arguments for creating a file on specific afs + default to None huggingface_dataset (Optional[Any]): Huggingface dataset object, only support DatasetDict and Dataset of Huggingface datasets. 
@@ -433,6 +454,7 @@ def load( qianfan_dataset_id=qianfan_dataset_id, bos_load_args=bos_load_args, bos_source_args=bos_source_args, + afs_source_args=afs_source_args, **kwargs, ) @@ -465,6 +487,7 @@ def save( qianfan_dataset_id: Optional[str] = None, qianfan_dataset_create_args: Optional[Dict[str, Any]] = None, bos_source_args: Optional[Dict[str, Any]] = None, + afs_source_args: Optional[Dict[str, Any]] = None, schema: Optional[Schema] = None, replace_source: Optional[bool] = None, **kwargs: Any, @@ -489,6 +512,9 @@ def save( bos_source_args: (Optional[Dict[str, Any]]): create arguments for creating a file on specific bos default to None + afs_source_args: (Optional[Dict[str, Any]]): + create arguments for creating a file on specific afs + default to None schema: (Optional[Schema]): schema used to validate before exporting data, default to None replace_source: (Optional[bool]): @@ -506,6 +532,7 @@ def save( qianfan_dataset_id=qianfan_dataset_id, qianfan_dataset_create_args=qianfan_dataset_create_args, bos_source_args=bos_source_args, + afs_source_args=afs_source_args, **kwargs, ) diff --git a/python/qianfan/errors.py b/python/qianfan/errors.py index a44807d8..094eaec4 100644 --- a/python/qianfan/errors.py +++ b/python/qianfan/errors.py @@ -72,6 +72,12 @@ class AccessTokenExpiredError(QianfanError): pass +class BearerTokenExpiredError(QianfanError): + """Exception when bearer token is expired""" + + pass + + class InternalError(QianfanError): """Exception when internal error occurs""" diff --git a/python/qianfan/evaluation/consts.py b/python/qianfan/evaluation/consts.py index ee5376ce..cf62486b 100644 --- a/python/qianfan/evaluation/consts.py +++ b/python/qianfan/evaluation/consts.py @@ -27,19 +27,55 @@ """ QianfanRefereeEvaluatorDefaultMaxScore: int = 5 -QianfanRefereeEvaluatorPromptTemplate: str = """你是一个好助手。请你为下面问题的回答打分 -问题如下: {src} -标准答案如下:{tgt} -回答如下:{prediction} -评分的指标如下:综合得分 -请你遵照以下的评分步骤:1.仔细阅读所提供的问题,确保你理解问题的要求和背景。 -2.仔细阅读所提供的标准答案,确保你理解问题的标准答案 
-3.阅读答案,并检查是否用词不当 -4.检查答案是否严格遵照了题目的要求,包括答题方式、答题长度、答题格式等等。 -根据答案的综合水平给出0到5的评分。如果答案存在明显的不合理之处,则应给出一个较低的评分。如果答案符合以上要求并且与参考答案含义相似,则应给出一个较高的评分。 -你的回答模版如下: -评分: 此处只能回答整数评分 -原因: 此处只能回答评分原因""" +QianfanRefereeEvaluatorPromptTemplate: str = """【系统】 +请作为一个公正的裁判,评估下面给定用户问题的AI助手所提供回答的质量。您的评估应该考虑以下因素: +{参照总体组设计,分为理解、生成、事实、逻辑、指令遵循五个维度综合考察AI助手能力。详细评分方法如下: + * 理解:仅考虑回答的扣题程度,不考虑回答的正确性。 + * 核心需求是否理解; + * 非核心需求是否理解; + * 生成:考虑(1)回答和问题的相关性、(2)生成文本的质量。 + * 核心需求是否体现在答案里; + * 核心需求体现在答案,但是否正确实现。 + * 逻辑:考虑回答的逻辑正确性与一致性 + * 创作/问答的逻辑主要指的是行文逻辑、发展逻辑、论证逻辑等; + * 信息处理/代码/数学计算/逻辑推理的逻辑包括推理/计算步骤与答案正确性; + * 事实:前提是符合中国的国情和政治立场、法律法规和文化价值观要准确, + 主要指回答问题涉及的外部客观事实正确性,回复提供的信息要准确、真实、可靠、有帮助。 + * 指令遵循:回答是否严格遵循用户问题的要求, + 比如是否提供了所有要求的信息,要按照给定样例格式输出回答,遇到选择或分类题应当直接输出答案而不用补充说明。} +请帮助我评估AI助手回答的好坏并给出对应的{{min_score}}到{{max_score}}得分,{最终只需要给出一个综合的得分。} +【用户的问题】 + ```json +{ + "instruction": "{{src}}", +} +``` +【参考的回答】 +```json +[ + { + "target": "{{tgt}}" + } +] +``` +【助手的回答】 +```json +[ + { + "answer": "{{prediction}}" + } +] +``` +【输出格式】 +```json +{ + "reason": "", + "score": "" +} +``` +请注意区分您的最终任务和用户问题中提出的任务,最终的任务是完成评估打分任务,而不要直接回答给定的用户问题。 +请按照输出格式给出评分理由和助手回答的得分,不要输出json格式外的内容。 +【评估结果】""" LocalJudgeEvaluatorPromptTemplate: str = """ 你是一名裁判员,负责为给定prompt的生成结果进行评分。 diff --git a/python/qianfan/evaluation/evaluation_manager.py b/python/qianfan/evaluation/evaluation_manager.py index c197aa3c..6731f44d 100644 --- a/python/qianfan/evaluation/evaluation_manager.py +++ b/python/qianfan/evaluation/evaluation_manager.py @@ -335,6 +335,7 @@ def _get_qianfan_evaluator_configuration_dict(self) -> Dict[str, Any]: ) input_argument_dict["appId"] = evaluator.app_id input_argument_dict["prompt"] = { + "templateName": "裁判员模型打分模板 (含参考答案)", "templateContent": QianfanRefereeEvaluatorPromptTemplate, "metric": evaluator.prompt_metrics, "steps": evaluator.prompt_steps, diff --git a/python/qianfan/resources/__init__.py b/python/qianfan/resources/__init__.py index e2eb0098..38ba92b1 100644 --- a/python/qianfan/resources/__init__.py +++ 
b/python/qianfan/resources/__init__.py @@ -14,6 +14,7 @@ from qianfan.resources.console.charge import Charge from qianfan.resources.console.data import Data from qianfan.resources.console.finetune import FineTune +from qianfan.resources.console.iam import IAM from qianfan.resources.console.memory import Memory from qianfan.resources.console.model import Model from qianfan.resources.console.prompt import Prompt @@ -49,4 +50,5 @@ "QfMessages", "QfResponse", "Memory", + "IAM", ] diff --git a/python/qianfan/resources/auth/oauth.py b/python/qianfan/resources/auth/oauth.py index 7999e9c3..6d0197be 100644 --- a/python/qianfan/resources/auth/oauth.py +++ b/python/qianfan/resources/auth/oauth.py @@ -14,7 +14,7 @@ import threading import time -from typing import Any, Dict, Optional, Tuple +from typing import Any, Callable, Dict, Optional, Tuple from qianfan.config import get_config from qianfan.consts import Consts @@ -42,26 +42,44 @@ class AuthManager(metaclass=Singleton): AuthManager is singleton to manage all access token in SDK """ - class AccessToken: + class Token: """ - Access Token object + Token object """ token: Optional[str] lock: threading.Lock alock: AsyncLock refresh_at: float - - def __init__(self, access_token: Optional[str] = None): + expire_at: float + refresh_func: Optional[Callable[..., Dict]] + """ + custom refresh function + return { + "token": "", + "expire_at": 0, # optional + "refresh_at": 0, # optional + } + """ + + def __init__( + self, + access_token: Optional[str] = None, + refresh_at: float = 0, + expire_at: float = 0, + refresh_func: Optional[Callable[..., Dict]] = None, + ): """ Init access token object """ self.token = access_token self.lock = threading.Lock() self.alock = AsyncLock() - self.refresh_at = 0 + self.refresh_at = refresh_at + self.expire_at = expire_at + self.refresh_func = refresh_func - _token_map: Dict[Tuple[str, str], AccessToken] + _token_map: Dict[Tuple[str, str], Token] def __init__(self) -> None: """ @@ -72,7 +90,13 @@ 
def __init__(self) -> None: self._lock = threading.Lock() self._alock = AsyncLock() - def _register(self, ak: str, sk: str, access_token: Optional[str] = None) -> bool: + def _register( + self, + ak: str, + sk: str, + access_token: Optional[str] = None, + refresh_func: Optional[Callable[..., Dict]] = None, + ) -> bool: """ add `(ak, sk)` to manager and return whether provided `(ak, sk)` is existed this function is not thread safe !!! @@ -80,7 +104,9 @@ def _register(self, ak: str, sk: str, access_token: Optional[str] = None) -> boo existed = True if (ak, sk) not in self._token_map: - self._token_map[(ak, sk)] = AuthManager.AccessToken(access_token) + self._token_map[(ak, sk)] = AuthManager.Token( + access_token, refresh_func=refresh_func + ) existed = False else: # if user provide new access token for existed (ak, sk), update it @@ -89,33 +115,43 @@ def _register(self, ak: str, sk: str, access_token: Optional[str] = None) -> boo self._token_map[(ak, sk)].refresh_at = 0 return existed - def register(self, ak: str, sk: str, access_token: Optional[str] = None) -> None: + def register( + self, + ak: str, + sk: str, + access_token: Optional[str] = None, + refresh_func: Optional[Callable[..., Dict]] = None, + ) -> None: """ add `(ak, sk)` to manager and update access token """ with self._lock: - existed = self._register(ak, sk, access_token) + existed = self._register(ak, sk, access_token, refresh_func) if not existed and access_token is None: - self.refresh_access_token(ak, sk) + self.refresh_token(ak, sk) async def aregister( - self, ak: str, sk: str, access_token: Optional[str] = None + self, + ak: str, + sk: str, + access_token: Optional[str] = None, + refresh_func: Optional[Callable[..., Dict]] = None, ) -> None: """ async add `(ak, sk)` to manager and update access token """ async with self._alock: - existed = self._register(ak, sk, access_token) + existed = self._register(ak, sk, access_token, refresh_func) if not existed and access_token is None: await 
self.arefresh_access_token(ak, sk) - def _get_access_token_object( + def _get_token_object( self, ak: str, sk: str - ) -> AccessToken: # pylint:disable=undefined-variable + ) -> Token: # pylint:disable=undefined-variable """ - get access token object by `(ak, sk)` + get token object by `(ak, sk)` this function is not thread safe !!! """ obj = self._token_map.get((ak, sk), None) @@ -123,11 +159,11 @@ def _get_access_token_object( raise InternalError("provided ak and sk are not registered") return obj - def _get_token_from_access_token_object( - self, obj: AccessToken, ak: str = "", sk: str = "" + def _get_token_from_token_object( + self, obj: Token, ak: str = "", sk: str = "" ) -> str: """ - get access token from access token object + get token from token object this function is not thread safe and should be protected by lock from obj !!! """ if obj.token is None: @@ -135,23 +171,26 @@ def _get_token_from_access_token_object( return "" return obj.token - def get_access_token(self, ak: str, sk: str) -> str: + def get_token(self, ak: str, sk: str) -> str: """ - get access token by `(ak, sk)` + get token by `(ak, sk)` """ with self._lock: - obj = self._get_access_token_object(ak, sk) + obj = self._get_token_object(ak, sk) + # 提前刷新 + if obj.expire_at != 0 and obj.expire_at - 10 < time.time(): + self.refresh_token(ak, sk) with obj.lock: - return self._get_token_from_access_token_object(obj, ak, sk) + return self._get_token_from_token_object(obj, ak, sk) - async def aget_access_token(self, ak: str, sk: str) -> str: + async def aget_token(self, ak: str, sk: str) -> str: """ - async get access token by `(ak, sk)` + async get token by `(ak, sk)` """ async with self._alock: - obj = self._get_access_token_object(ak, sk) + obj = self._get_token_object(ak, sk) async with obj.alock: - return self._get_token_from_access_token_object(obj, ak, sk) + return self._get_token_from_token_object(obj, ak, sk) def _auth_request(self, ak: str, sk: str) -> QfRequest: """ @@ -169,7 +208,7 @@ 
def _auth_request(self, ak: str, sk: str) -> QfRequest: ) def _update_access_token( - self, obj: AccessToken, response: Dict[str, Any], ak: str = "", sk: str = "" + self, obj: Token, response: Dict[str, Any], ak: str = "", sk: str = "" ) -> None: """ update access token from response of auth request @@ -184,15 +223,15 @@ def _update_access_token( " https://console.bce.baidu.com/qianfan/ais/console/applicationConsole/application" ) err_msg = response.get("error_description", "AK/SK is not correct") - chineses_err_msg = response.get("error_description", "AK/SK 错误") + chinese_err_msg = response.get("error_description", "AK/SK 错误") if err_msg == "unknown client id": err_msg = f"AK({_masked_ak(ak)}) is not correct" - chineses_err_msg = f"AK(`{_masked_ak(ak)}`) 错误" + chinese_err_msg = f"AK(`{_masked_ak(ak)}`) 错误" if err_msg == "Client authentication failed": err_msg = f"SK({_masked_ak(sk)}) is not correct" - chineses_err_msg = f"SK(`{_masked_ak(sk)}`) 错误" + chinese_err_msg = f"SK(`{_masked_ak(sk)}`) 错误" err_msg = exception_msg_tmpl.format( - err_msg=err_msg, chinese_err_msg=chineses_err_msg + err_msg=err_msg, chinese_err_msg=chinese_err_msg ) log_error(err_msg) raise AuthError(err_msg) @@ -203,62 +242,96 @@ def _update_access_token( ) ) return - obj.token = response["access_token"] - obj.refresh_at = time.time() + self._update_token(obj, response["access_token"]) + + def _update_token( + self, + obj: Token, + token: str, + refresh_at: float = 0, + expire_at: float = 0, + ) -> None: + obj.token = token + obj.refresh_at = refresh_at or time.time() + if expire_at != 0: + obj.expire_at = expire_at - def _refresh_access_token_too_often(self, obj: AccessToken) -> bool: + def _refresh_token_too_often(self, obj: Token) -> bool: """ - check if access token is refreshed too often + check if token is refreshed too often """ if ( time.time() - obj.refresh_at < get_config().ACCESS_TOKEN_REFRESH_MIN_INTERVAL ): - log_info("access_token is already refreshed, skip refresh.") + 
log_info("token is already refreshed, skip refresh.") return True return False - def refresh_access_token(self, ak: str, sk: str) -> None: - """ - refresh access token of `(ak, sk)` - """ + def refresh_token(self, ak: str, sk: str) -> None: with self._lock: - obj = self._get_access_token_object(ak, sk) + obj = self._get_token_object(ak, sk) with obj.lock: - log_info(f"trying to refresh access_token for ak `{_masked_ak(ak)}`") + log_info(f"trying to refresh token for ak `{_masked_ak(ak)}`") # in case multiple threads try to refresh access token at the same time # the token should not be refreshed multiple times - if self._refresh_access_token_too_often(obj): + if self._refresh_token_too_often(obj): return try: - resp = self._client.request(self._auth_request(ak, sk)) - json_body = resp.json() - self._update_access_token(obj, json_body, ak, sk) + if obj.refresh_func: + # bearer token + token_resp = obj.refresh_func(ak, sk) + token = token_resp.get("token", "") + self._update_token( + obj, + token=token, + refresh_at=token_resp.get("refresh_at", 0), + expire_at=token_resp.get("expire_at", 0), + ) + else: + # access token + resp = self._client.request(self._auth_request(ak, sk)) + json_body = resp.json() + self._update_access_token(obj, json_body, ak, sk) except AuthError: raise except Exception as e: - log_error(f"refresh access token failed with exception {str(e)}") - return + log_error(f"refresh token failed with exception {str(e)}") + raise e - log_info("sucessfully refresh access_token") + log_info("successfully refresh token") async def arefresh_access_token(self, ak: str, sk: str) -> None: """ async refresh access token of `(ak, sk)` """ async with self._alock: - obj = self._get_access_token_object(ak, sk) + obj = self._get_token_object(ak, sk) async with obj.alock: log_info(f"trying to refresh access_token for ak `{_masked_ak(ak)}`") # in case multiple threads try to refresh access token at the same time # the token should not be refreshed multiple times - if 
self._refresh_access_token_too_often(obj): + if self._refresh_token_too_often(obj): return try: - resp, session = await self._client.arequest(self._auth_request(ak, sk)) - async with session: - json_body = await resp.json() - self._update_access_token(obj, json_body, ak, sk) + if obj.refresh_func: + # bearer token + # TODO: implement bearer token refresh async + token_resp = obj.refresh_func(ak, sk) + token = token_resp.get("token", "") + self._update_token( + obj, + token=token, + refresh_at=token_resp.get("refresh_at", 0), + expire_at=token_resp.get("expire_at", 0), + ) + else: + resp, session = await self._client.arequest( + self._auth_request(ak, sk) + ) + async with session: + json_body = await resp.json() + self._update_access_token(obj, json_body, ak, sk) except AuthError: raise except Exception as e: @@ -270,23 +343,22 @@ async def arefresh_access_token(self, ak: str, sk: str) -> None: class Auth(object): """ - object to maintain acccess token for open api call + object to maintain authorization info for open api call + including access_token, ak/sk and access_key/secret_key """ _ak: Optional[str] = None _sk: Optional[str] = None _access_token: Optional[str] = None + _bearer_token: Optional[str] = None _access_key: Optional[str] = None _secret_key: Optional[str] = None _registered: bool = False - _console_ak_to_app_ak: Dict[Tuple[str, str], Tuple[str, str]] = {} - """ - (access_key, secret_key) -> (ak, sk) - map which convert console ak/sk to qianfan ak/sk - use as cache to avoid querying console ak/sk multple times - """ + _refresh_func: Optional[Callable[..., Dict]] = None - def __init__(self, **kwargs: Any) -> None: + def __init__( + self, refresh_func: Optional[Callable[..., Dict]] = None, **kwargs: Any + ) -> None: """ recv `ak`, `sk` and `access_token` from kwargs if the args does not contain the arguments, env variable will be used @@ -300,17 +372,26 @@ def __init__(self, **kwargs: Any) -> None: self._access_token = ( kwargs.get("access_token", None) 
or get_config().ACCESS_TOKEN ) + self._bearer_token = ( + kwargs.get("bearer_token", None) or get_config().BEARER_TOKEN + ) self._access_key = kwargs.get("access_key", None) or get_config().ACCESS_KEY self._secret_key = kwargs.get("secret_key", None) or get_config().SECRET_KEY + self._refresh_func = refresh_func if not self._credential_available() and not get_config().NO_AUTH: raise InvalidArgumentError( - "no enough credential found, any one of (access_key, secret_key)," - " (ak, sk), access_token must be provided" + "no enough credential found, use any one of (access_key, secret_key)," + " (ak, sk), (access_token) in api v1 or" + " any of (ak, sk), (bearer token) in api v2" ) if ( self._access_token is None and (self._ak is None or self._sk is None) - and (self._access_key is not None and self._secret_key is not None) + and ( + self._access_key is not None + and self._secret_key is not None + and refresh_func is None + ) ): self._registered = True @@ -319,7 +400,16 @@ def _register(self) -> None: register the access token to manager, so that it can be refreshed automatically """ if not self._registered: - if self._access_token is None: + if self._access_key and self._secret_key and self._refresh_func: + AuthManager().register( + self._access_key, + self._secret_key, + self._bearer_token, + self._refresh_func, + ) + elif self._bearer_token is not None and self._refresh_func: + self._registered = True + elif self._access_token is None: # if access_token is not provided, both ak and sk should be provided if self._ak is None or self._sk is None: raise InvalidArgumentError( @@ -339,6 +429,15 @@ async def _aregister(self) -> None: register the access token to manager, so that it can be refreshed automatically """ if not self._registered: + if self._access_key and self._secret_key and self._refresh_func: + await AuthManager().aregister( + self._access_key, + self._secret_key, + self._bearer_token, + self._refresh_func, + ) + elif self._bearer_token is not None and 
self._refresh_func: + self._registered = True if self._access_token is None: # if access_token is not provided, both ak and sk should be provided if self._ak is None or self._sk is None: @@ -356,6 +455,20 @@ async def _aregister(self) -> None: ) self._registered = True + def refresh_bearer_token(self) -> None: + """ + refresh `bearer_token` using `access_key` and `secret_key` + """ + if self._access_key is None or self._secret_key is None: + log_warn( + "access_key or secret_key is not set, refresh bearer_token will not" + " work." + ) + return + self._register() + AuthManager().refresh_token(self._access_key, self._secret_key) + self._bearer_token = None + def refresh_access_token(self) -> None: """ refresh `access_token` using `ak` and `sk` @@ -364,7 +477,7 @@ def refresh_access_token(self) -> None: log_warn("AK or SK is not set, refresh access_token will not work.") return self._register() - AuthManager().refresh_access_token(self._ak, self._sk) + AuthManager().refresh_token(self._ak, self._sk) self._access_token = None async def arefresh_access_token(self) -> None: @@ -379,12 +492,27 @@ async def arefresh_access_token(self) -> None: self._access_token = None def _credential_available(self) -> bool: + if self._refresh_func: + if self._bearer_token is not None: + return True + elif self._access_key is not None and self._secret_key is not None: + return True + else: + log_warn( + "no enough credential found, any one of (access_key, secret_key)," + " (bearer_token) must be provided" + ) + return False if self._access_token is not None: return True if self._ak is not None and self._sk is not None: return True if self._access_key is not None and self._secret_key is not None: return True + log_warn( + "no enough credential found, any one of (access_key, secret_key)," + " (ak, sk), (access_token) must be provided" + ) return False def access_token(self) -> str: @@ -398,7 +526,7 @@ def access_token(self) -> str: # use access_key and secret_key to auth # so no 
access_token here return "" - return AuthManager().get_access_token(self._ak, self._sk) + return AuthManager().get_token(self._ak, self._sk) async def a_access_token(self) -> str: """ @@ -411,4 +539,30 @@ async def a_access_token(self) -> str: # use access_key and secret_key to auth # so no access_token here return "" - return await AuthManager().aget_access_token(self._ak, self._sk) + return await AuthManager().aget_token(self._ak, self._sk) + + def bearer_token(self) -> str: + """ + get current `bearer_token` + """ + if self._bearer_token is not None and ( + self._access_key is None or self._secret_key is None + ): + return self._bearer_token + self._register() + if self._access_key is not None and self._secret_key is not None: + return AuthManager().get_token(self._access_key, self._secret_key) + return "" + + async def a_bearer_token(self) -> str: + """ + async get current `bearer_token` + """ + if self._bearer_token is not None and ( + self._access_key is None or self._secret_key is None + ): + return self._bearer_token + await self._aregister() + if self._access_key is not None and self._secret_key is not None: + return await AuthManager().aget_token(self._access_key, self._secret_key) + return "" diff --git a/python/qianfan/resources/batch_inference/helper/helper.py b/python/qianfan/resources/batch_inference/helper/helper.py index b4af0a7b..50310924 100644 --- a/python/qianfan/resources/batch_inference/helper/helper.py +++ b/python/qianfan/resources/batch_inference/helper/helper.py @@ -50,6 +50,9 @@ def get(self, remote_path: str, local_path: str) -> str: def rmr(self, *params: Any) -> str: return self._exec("rmr", *params) + def rm(self, path: str) -> str: + return self._exec("rm", path) + def _get_exec_cmd(self, cmd: str, *params: Any) -> str: log_debug(f"run cmd {cmd} {params}") exec_cmd = ( @@ -85,6 +88,16 @@ def mkdir(self, *params: Any) -> str: def cp(self, *params: Any) -> str: return self._exec("cp", *params) + def get_modify_time(self, path: str, 
*args: Any, **kwargs: Any) -> str: + return self._exec("stat", "%y", path) + + def test(self, path: str, *args: Any, **kwargs: Any) -> int: + try: + self._exec("test", *[*args, path]) + return 0 + except ValueError: + return 1 + def call_bf(afs_config: dict, **kwargs: Any) -> QfResponse: if not kwargs.get("retry_count"): diff --git a/python/qianfan/resources/console/consts.py b/python/qianfan/resources/console/consts.py index 9bc20ba5..798e8711 100644 --- a/python/qianfan/resources/console/consts.py +++ b/python/qianfan/resources/console/consts.py @@ -392,3 +392,18 @@ class ModelCompTaskStatus(str, Enum): """终止中""" Stopped: str = "stopped" """已终止""" + + +class V2: + class DatasetFormat(str, Enum): + PromptResponse: str = "PromptResponse" + Text: str = "Text" + DPOPromptChosenRejected: str = "DPO_PromptChosenRejected" + KTOPromptChosenRejected: str = "KTO_PromptChosenRejected" + PromptSortedResponses: str = "PromptSortedresponses" + Prompt: str = "Prompt" + PromptImage: str = "PromptImage" + + class StorageType(str, Enum): + Bos: str = "BOS" + SysStorage: str = "sysStorage" diff --git a/python/qianfan/resources/console/data.py b/python/qianfan/resources/console/data.py index 254c43b4..2107d6d3 100644 --- a/python/qianfan/resources/console/data.py +++ b/python/qianfan/resources/console/data.py @@ -20,6 +20,9 @@ from qianfan.consts import Consts from qianfan.errors import QianfanError +from qianfan.resources.console.consts import ( + V2 as V2Consts, +) from qianfan.resources.console.consts import ( DataExportDestinationType, DataProjectType, @@ -1012,3 +1015,494 @@ def list_offline_batch_inference_task( if v is not None } return req + + class V2: + @classmethod + def base_api_route(cls) -> str: + """ + base api url route for dataset V2. 
+ + Returns: + str: base api url route + """ + return Consts.DatasetV2BaseRouteAPI + + @classmethod + @console_api_request + def create_dataset( + cls, + dataset_name: str, + dataset_format: V2Consts.DatasetFormat, + storage_type: V2Consts.StorageType, + storage_path: Optional[str] = None, + **kwargs: Any, + ) -> QfRequest: + """ + create dataset. + + Parameters: + dataset_name (str): + dataset name. + dataset_format (V2Consts.DatasetFormat): + dataset format. + storage_type (V2Consts.StorageType): + storage type. + storage_path (Optional[str], optional): + storage path. Defaults to None. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. + """ + post_body_dict = { + "datasetName": dataset_name, + "dataFormat": dataset_format.value, + "storageType": storage_type.value, + } + + if storage_path: + if storage_path[-1] != "/": + storage_path += "/" + + post_body_dict["storagePath"] = storage_path + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.DatasetV2CreateDatasetAction), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def get_dataset_list( + cls, + marker: Optional[str] = None, + max_keys: int = 10, + page_reverse: bool = False, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> QfRequest: + """ + get dataset list. + + Parameters: + marker (Optional[str], optional): + marker of the first page. Defaults to None. + max_keys (int, optional): + max keys of the page. Defaults to 10. + page_reverse (bool, optional): + page reverse or not. Defaults to False. + filter (Optional[Dict[str, Any]], optional): + filter conditions. Defaults to None. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. 
+ """ + post_json_body: Dict[str, Any] = { + "maxKeys": max_keys, + "pageReverse": page_reverse, + } + + if marker: + post_json_body["marker"] = marker + + if filter: + post_json_body["filter"] = filter + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.DatasetV2GetDatasetListAction), + ) + req.json_body = post_json_body + + return req + + @classmethod + @console_api_request + def delete_dataset( + cls, + dataset_id: str, + **kwargs: Any, + ) -> QfRequest: + """ + delete dataset. + + Parameters: + dataset_id (str): + dataset id. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. + """ + + post_body_dict = { + "datasetId": dataset_id, + } + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.DatasetV2DeleteDatasetAction), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def create_dataset_version( + cls, + dataset_id: str, + description: Optional[str] = None, + **kwargs: Any, + ) -> QfRequest: + """ + create dataset version. + + Parameters: + dataset_id (str): + dataset id. + description (Optional[str], optional): + dataset version description. Defaults to None. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. + """ + post_body_dict = { + "datasetId": dataset_id, + } + + if description: + post_body_dict["description"] = description + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.DatasetV2CreateDatasetVersionAction), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def get_dataset_version_info( + cls, + version_id: str, + **kwargs: Any, + ) -> QfRequest: + """ + get dataset version info. 
+ + Parameters: + version_id (str): + dataset version id. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. + """ + + post_body_dict = { + "versionId": version_id, + } + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query( + Consts.DatasetV2GetDatasetVersionInfoAction + ), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def delete_dataset_version( + cls, + version_id: str, + **kwargs: Any, + ) -> QfRequest: + """ + delete dataset version. + + Parameters: + version_id (str): + dataset version id. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. + """ + + post_body_dict = { + "versionId": version_id, + } + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query(Consts.DatasetV2DeleteDatasetVersionAction), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def publish_dataset_version( + cls, + version_id: str, + **kwargs: Any, + ) -> QfRequest: + """ + publish dataset version. + + Parameters: + version_id (str): + dataset version id. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. 
+ """ + + post_body_dict = { + "versionId": version_id, + } + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query( + Consts.DatasetV2PublishDatasetVersionAction + ), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def get_dataset_version_list( + cls, + dataset_id: str, + marker: Optional[str] = None, + max_keys: int = 10, + page_reverse: bool = False, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> QfRequest: + """ + get dataset version list. + + Parameters: + dataset_id (str): + dataset id. + marker (Optional[str], optional): + marker. Defaults to None. + max_keys (int, optional): + max keys. Defaults to 10. + page_reverse (bool, optional): + page reverse. Defaults to False. + filter (Optional[Dict[str, Any]], optional): + filter. Defaults to None. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. + """ + + post_body_dict = { + "datasetId": dataset_id, + "maxKeys": max_keys, + "pageReverse": page_reverse, + } + + if marker: + post_body_dict["marker"] = marker + + if filter: + post_body_dict["filter"] = filter + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query( + Consts.DatasetV2GetDatasetVersionListAction + ), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def create_dataset_version_import_task( + cls, + version_id: str, + files: List[str], + **kwargs: Any, + ) -> QfRequest: + """ + create dataset version import task. + + Parameters: + version_id (str): + dataset version id. + files (List[str]): + file paths. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. 
+ """ + + post_body_dict = { + "versionId": version_id, + "files": files, + } + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query( + Consts.DatasetV2CreateDatasetVersionImportTaskAction + ), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def get_dataset_version_import_task_info( + cls, + task_id: str, + **kwargs: Any, + ) -> QfRequest: + """ + get dataset version import task info. + + Parameters: + task_id (str): + task id. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. + """ + + post_body_dict = { + "taskId": task_id, + } + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query( + Consts.DatasetV2GetDatasetVersionImportTaskInfoAction + ), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def create_dataset_version_export_task( + cls, + version_id: str, + storage_type: V2Consts.StorageType, + storage_path: Optional[str] = None, + **kwargs: Any, + ) -> QfRequest: + """ + create dataset version export task. + + Parameters: + version_id (str): + dataset version id. + storage_type (V2Consts.StorageType): + storage type. + storage_path (Optional[str], optional): + storage path. Defaults to None. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. 
+ """ + + post_body_dict = { + "versionId": version_id, + "storageType": storage_type.value, + } + + if storage_path: + post_body_dict["storagePath"] = storage_path + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query( + Consts.DatasetV2CreateDatasetVersionExportTaskAction + ), + ) + req.json_body = post_body_dict + + return req + + @classmethod + @console_api_request + def get_dataset_version_export_task_info( + cls, + task_id: str, + **kwargs: Any, + ) -> QfRequest: + """ + get dataset version export task info. + + Parameters: + task_id (str): + task id. + + Note: + The `@console_api_request` decorator is applied to this method, + enabling it to send the generated QfRequest + and return a QfResponse to the user. + """ + + post_body_dict = { + "taskId": task_id, + } + + req = QfRequest( + method="POST", + url=cls.base_api_route(), + query=_get_console_v2_query( + Consts.DatasetV2GetDatasetVersionExportTaskInfoAction + ), + ) + req.json_body = post_body_dict + + return req diff --git a/python/qianfan/resources/console/iam.py b/python/qianfan/resources/console/iam.py new file mode 100644 index 00000000..6e935235 --- /dev/null +++ b/python/qianfan/resources/console/iam.py @@ -0,0 +1,63 @@ +# Copyright (c) 2023 Baidu, Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +""" +iam API +""" +from typing import Any + +from qianfan.config import get_config +from qianfan.consts import Consts +from qianfan.resources.console.utils import console_api_request +from qianfan.resources.typing import QfRequest, QfResponse + + +class IAM(object): + @classmethod + def create_bearer_token( + cls, + expire_in_seconds: int = 1000, + **kwargs: Any, + ) -> QfResponse: + """ + create a bearer token for call api v2. + + Parameters: + expire_in_seconds (int): + expire time of the token, in seconds. + kwargs: + Additional keyword arguments that can be passed to customize the + request. + """ + kwargs["host"] = get_config().IAM_BASE_URL + return cls._iam_call( + req=QfRequest( + method="GET", + url=Consts.IAMBearerTokenAPI, + query={ + "expireInSeconds": str(expire_in_seconds), + }, + ), + **kwargs, + ) + + @classmethod + @console_api_request + def _iam_call(cls, req: QfRequest, **kwargs: Any) -> QfRequest: + """ + inner caller for iam api, which accept a new host for iam api. + The `@console_api_request` decorator is applied to this method, enabling + it to send the generated QfRequest and return a QfResponse to the user. + """ + return req diff --git a/python/qianfan/resources/http_client.py b/python/qianfan/resources/http_client.py index 35d3f977..cbeb92e3 100644 --- a/python/qianfan/resources/http_client.py +++ b/python/qianfan/resources/http_client.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-from typing import Any, AsyncIterator, Dict, Iterator, Optional, Tuple +from typing import Any, AsyncGenerator, Dict, Iterator, Optional, Tuple import aiohttp import requests @@ -111,7 +111,7 @@ async def arequest( async def arequest_stream( self, req: QfRequest - ) -> AsyncIterator[Tuple[bytes, aiohttp.ClientResponse]]: + ) -> AsyncGenerator[Tuple[bytes, aiohttp.ClientResponse], None]: """ async stream request """ diff --git a/python/qianfan/resources/llm/base.py b/python/qianfan/resources/llm/base.py index 3a41a4af..541cb470 100644 --- a/python/qianfan/resources/llm/base.py +++ b/python/qianfan/resources/llm/base.py @@ -51,7 +51,7 @@ QfResponse, RetryConfig, ) -from qianfan.utils import log_info, log_warn, utils +from qianfan.utils import log_debug, log_info, log_warn, utils from qianfan.utils.cache.base import KvCache from qianfan.version import VERSION @@ -164,7 +164,14 @@ def __init__( ) -> None: self._version = str(version) if version else "1" self._real = self._real_base(self._version, **kwargs)(**kwargs) - self._backup = self._real_base("1", **kwargs)(**kwargs) + if self._version != "1": + try: + self._backup = self._real_base("1", **kwargs)(**kwargs) + except Exception as e: + log_debug( + f"Failed to create V1 backup instance, error: {e}, " + "will use the latest version instead." 
+ ) @classmethod def _real_base(cls, version: str, **kwargs: Any) -> Type[BaseResource]: @@ -214,7 +221,7 @@ def _need_downgrade(self) -> bool: return get_config().V2_INFER_API_DOWNGRADE and self._version != "1" def _do(self, **kwargs: Any) -> Union[QfResponse, Iterator[QfResponse]]: - if self._need_downgrade(): + if self._need_downgrade() and self._backup: return self._do_downgrade(**kwargs) # assert self._real has function `do` return self._real.do(**kwargs) # type: ignore @@ -612,6 +619,7 @@ def __init__( self, model: Optional[str] = None, endpoint: Optional[str] = None, + use_custom_endpoint: bool = False, **kwargs: Any, ) -> None: """ @@ -620,6 +628,7 @@ def __init__( self._model = model self._endpoint = endpoint self._client = create_api_requestor(**kwargs) + self.use_custom_endpoint = use_custom_endpoint def _update_model_and_endpoint( self, model: Optional[str], endpoint: Optional[str] @@ -644,6 +653,9 @@ def _update_model_and_endpoint( ) endpoint = model_info.endpoint else: + # 适配非公有云等需要增加chat/等前缀的endpoint + if self.use_custom_endpoint or get_config().USE_CUSTOM_ENDPOINT: + return model, endpoint endpoint = self._convert_endpoint(model, endpoint) return model, endpoint @@ -905,9 +917,17 @@ def get_model_info(cls, model: str) -> QfLLMInfo: class BaseResourceV2(BaseResource): - def __init__(self, model: Optional[str] = None, **kwargs: Any) -> None: + def __init__( + self, + model: Optional[str] = None, + app_id: Optional[str] = None, + bearer_token: Optional[str] = None, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) self._model = model + self._app_id = app_id + self._bearer_token = bearer_token self._client = QfAPIV2Requestor(**kwargs) def _request( @@ -938,6 +958,18 @@ def _request( return resp + def _generate_header( + self, model: Optional[str], stream: bool, **kwargs: Any + ) -> JsonBody: + """ + generate header + """ + + base_headers = super()._generate_header(model, stream, **kwargs) + if self._app_id and "app_id" not in base_headers: + 
base_headers["appid"] = self._app_id + return base_headers + async def _arequest( self, model: Optional[str], @@ -1031,13 +1063,8 @@ def get_latest_supported_models( # get preset services: for s in svc_list: try: - [api_type, model_endpoint] = trim_prefix( - s["url"], - "{}{}/".format( - DefaultValue.BaseURL, - get_config().MODEL_API_PREFIX, - ), - ).split("/") + splits = s["url"].split("/") + api_type, model_endpoint = splits[-2], splits[-1] model_info = _runtime_models_info.get(api_type) if model_info is None: model_info = {} diff --git a/python/qianfan/resources/llm/chat_completion.py b/python/qianfan/resources/llm/chat_completion.py index 6e6df695..11e1163d 100644 --- a/python/qianfan/resources/llm/chat_completion.py +++ b/python/qianfan/resources/llm/chat_completion.py @@ -37,7 +37,7 @@ from qianfan.resources.llm.function import Function from qianfan.resources.tools.tokenizer import Tokenizer from qianfan.resources.typing import JsonBody, QfLLMInfo, QfMessages, QfResponse, QfRole -from qianfan.utils.logging import log_error, log_info +from qianfan.utils.logging import log_error, log_info, log_warn class _ChatCompletionV1(BaseResourceV1): @@ -1738,6 +1738,17 @@ def batch_do( """ + if "enable_reading_buffer" in kwargs: + log_warn( + "enable_reading_buffer has been deprecated, please use" + " show_total_latency instead" + ) + if ( + isinstance(kwargs["enable_reading_buffer"], bool) + and kwargs["enable_reading_buffer"] + ): + show_total_latency = True + def worker( inner_func: Callable, **kwargs: Any ) -> Union[List[QfResponse], Iterator[QfResponse], QfResponse, Exception]: @@ -1805,6 +1816,18 @@ async def abatch_do( ``` """ + + if "enable_reading_buffer" in kwargs: + log_warn( + "enable_reading_buffer has been deprecated, please use" + " show_total_latency instead" + ) + if ( + isinstance(kwargs["enable_reading_buffer"], bool) + and kwargs["enable_reading_buffer"] + ): + show_total_latency = True + task_list: List[Callable] async def worker( diff --git 
a/python/qianfan/resources/requestor/base.py b/python/qianfan/resources/requestor/base.py index 01e877cf..87182b31 100644 --- a/python/qianfan/resources/requestor/base.py +++ b/python/qianfan/resources/requestor/base.py @@ -82,7 +82,7 @@ def _check_if_status_code_is_200(response: requests.Response) -> None: check whether the status code of response is ok(200) if the status code is not 200, raise a `RequestError` """ - if response.status_code != 200: + if response.status_code >= 300 or response.status_code < 200: failed_msg = ( f"http request url {response.url} failed with http status code" f" {response.status_code}\n" @@ -298,6 +298,7 @@ def __init__(self, **kwargs: Any) -> None: """ self._client = HTTPClient(**kwargs) self._rate_limiter = VersatileRateLimiter(**kwargs) + self._host = kwargs.get("host") def _preprocess_request(self, request: QfRequest) -> QfRequest: return request diff --git a/python/qianfan/resources/requestor/console_requestor.py b/python/qianfan/resources/requestor/console_requestor.py index 2dde976d..44fd0a22 100644 --- a/python/qianfan/resources/requestor/console_requestor.py +++ b/python/qianfan/resources/requestor/console_requestor.py @@ -17,7 +17,7 @@ """ from copy import deepcopy from typing import Any, Dict -from urllib.parse import urlparse +from urllib.parse import urlparse, urlunparse import qianfan.errors as errors from qianfan import get_config @@ -57,7 +57,7 @@ def _request_console_api( def _helper() -> QfResponse: req_copy = deepcopy(req) - ConsoleAPIRequestor._sign(req_copy, ak, sk) + self._sign(req_copy, ak, sk) return self._request(req_copy) return self._with_retry(retry_config, _helper) @@ -72,22 +72,29 @@ async def _async_request_console_api( req.retry_config = retry_config async def _helper() -> QfResponse: - ConsoleAPIRequestor._sign(req, ak, sk) + self._sign(req, ak, sk) return await self._async_request(req) return await self._async_with_retry(retry_config, _helper) - @staticmethod - def _sign(request: QfRequest, ak: str, 
sk: str) -> None: + def _sign(self, request: QfRequest, ak: str, sk: str) -> None: """ sign the request """ parsed_uri = urlparse(get_config().CONSOLE_API_BASE_URL) - host = parsed_uri.netloc + if self._host: + parsed_new = urlparse(self._host) + scheme = parsed_new.scheme if parsed_new.scheme else parsed_uri.scheme + netloc = parsed_new.netloc if parsed_new.netloc else parsed_uri.netloc + parsed_uri = parsed_uri._replace(scheme=scheme, netloc=netloc) + final_base_url = urlunparse(parsed_uri) + else: + final_base_url = get_config().CONSOLE_API_BASE_URL + request.headers = { "Content-Type": "application/json", - "Host": host, + "Host": parsed_uri.netloc, **request.headers, } iam_sign(ak, sk, request) - request.url = get_config().CONSOLE_API_BASE_URL + request.url + request.url = final_base_url + request.url diff --git a/python/qianfan/resources/requestor/openapi_requestor.py b/python/qianfan/resources/requestor/openapi_requestor.py index f2794da5..03ed6a68 100644 --- a/python/qianfan/resources/requestor/openapi_requestor.py +++ b/python/qianfan/resources/requestor/openapi_requestor.py @@ -15,10 +15,10 @@ """ Qianfan API Requestor """ - import copy import json import os +from datetime import datetime from typing import ( Any, AsyncIterator, @@ -207,7 +207,11 @@ async def _async_request_stream( if "json" in resp.headers.get("content-type", ""): body, _ = await responses.__anext__() - self._check_error(json.loads(body)) + try: + self._check_error(json.loads(body)) + except Exception as e: + await responses.aclose() + raise e async def iter() -> AsyncIterator[QfResponse]: nonlocal responses @@ -414,6 +418,7 @@ def llm( llm related api request """ log_debug(f"requesting llm api endpoint: {endpoint}") + # TODO 应该放在Adapter中做处理。 for m in body.get("messages", []): if m.get("role", "") == "function": if not m.get("name", None): @@ -590,7 +595,7 @@ def _sign(request: QfRequest, ak: str, sk: str) -> None: request.headers = { "Content-Type": "application/json", "Host": host, - 
**request.headers, + **(request.headers if request.headers else {}), } iam_sign(ak, sk, request) request.url = url @@ -679,6 +684,30 @@ async def _helper() -> QfResponse: class QfAPIV2Requestor(QfAPIRequestor): + def __init__(self, **kwargs: Any) -> None: + super().__init__(refresh_func=QfAPIV2Requestor._refresh_bearer_token, **kwargs) + + @staticmethod + def _refresh_bearer_token(*args: Any, **kwargs: Any) -> Dict: + """ + refresh bearer token + """ + from qianfan.resources.console.iam import IAM + + resp = IAM.create_bearer_token( + expire_in_seconds=get_config().BEARER_TOKEN_EXPIRED_INTERVAL + ) + + def _convert_time_str_to_sec(time_str: str) -> float: + time_obj = datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%S.%fZ") + return time_obj.replace().timestamp() + + return { + "token": resp.body["token"], + "refresh_at": _convert_time_str_to_sec(resp.body["createTime"]), + "expire_at": _convert_time_str_to_sec(resp.body["expireTime"]), + } + def _llm_api_url(self, endpoint: str) -> str: """ convert endpoint to llm api url @@ -692,29 +721,41 @@ def _add_access_token( self, req: QfRequest, auth: Optional[Auth] = None ) -> QfRequest: """ - add access token to QfRequest + add bearer token to QfRequest V2 """ + if get_config().NO_AUTH: + # 配置无鉴权,不签名,不抛出需要刷新token的异常,直接跳出。 + return req if auth is None: auth = self._auth + bearer_token = auth.bearer_token() + if bearer_token == "": + raise errors.BearerTokenExpiredError + else: + # use openapi auth + req.headers["Authorization"] = f"Bearer {bearer_token}" + return req - # use IAM auth - access_key = auth._access_key - secret_key = auth._secret_key - if access_key is None or secret_key is None: - extra_msg = "" - if get_config().AK is not None or get_config().SK is not None: - extra_msg = ( - " AK and SK cannot be used in V2 API. V2 推理 API 不支持通过 AK 和" - " SK 进行鉴权,请换用 access_key 和 secret_key。" - ) - raise errors.InvalidArgumentError( - "access_key and secret_key must be provided! 未提供 access_key 或" - " secret_key!" 
- + extra_msg - ) - self._sign(req, access_key, secret_key) + def _retry_if_token_expired(self, func: Callable[..., _T]) -> Callable[..., _T]: + """ + this is a wrapper to deal with token expired error + """ + token_refreshed = False - return req + def retry_wrapper(*args: Any, **kwargs: Any) -> _T: + nonlocal token_refreshed + # if token is refreshed, token expired exception will not be dealt with + if not token_refreshed: + try: + return func(*args) + except errors.BearerTokenExpiredError: + # refresh token and set token_refreshed flag + self._auth.refresh_bearer_token() + token_refreshed = True + # then fallthrough and try again + return func(*args, **kwargs) + + return retry_wrapper async def _async_add_access_token( self, req: QfRequest, auth: Optional[Auth] = None @@ -745,6 +786,9 @@ def _check_error(self, body: Dict[str, Any]) -> None: raise errors.APIError(error_code, err_msg, req_id) + # TODO 加上过期的raise BearerTokenExpiredError + # 当前无法区分 + def create_api_requestor(*args: Any, **kwargs: Any) -> QfAPIRequestor: if get_config().ENABLE_PRIVATE: diff --git a/python/qianfan/tests/chat_completion_v2_test.py b/python/qianfan/tests/chat_completion_v2_test.py index b5632f05..5377eb30 100644 --- a/python/qianfan/tests/chat_completion_v2_test.py +++ b/python/qianfan/tests/chat_completion_v2_test.py @@ -17,6 +17,7 @@ """ import threading +import time import pytest @@ -24,6 +25,9 @@ import qianfan.tests.utils from qianfan.consts import Consts +TEST_BEARER_TOKEN = ( + "bce-v3/ALTAK-JZasis7GfnokSLLXykKHj/054c2e64c06db4d6019f0dbfc964e90aa3fc3ddd" +) TEST_MODEL = "ernie-unit-test" TEST_MESSAGE = [ @@ -268,3 +272,47 @@ def test_in_other_thread(): t = threading.Thread(target=test_generate) t.start() t.join() + + +def test_auth_using_bearer_token(): + ak, sk, access_key, secret_key = ( + qianfan.get_config().AK, + qianfan.get_config().SK, + qianfan.get_config().ACCESS_KEY, + qianfan.get_config().SECRET_KEY, + ) + qianfan.get_config().AK = None + qianfan.get_config().SK = None 
+ qianfan.get_config().ACCESS_KEY = None + qianfan.get_config().SECRET_KEY = None + qianfan.get_config().BEARER_TOKEN = TEST_BEARER_TOKEN + resp = qianfan.ChatCompletion(version="2").do(messages=TEST_MESSAGE[:1]) + assert resp.body.get("choices") is not None + qianfan.get_config().AK = ak + qianfan.get_config().SK = sk + qianfan.get_config().ACCESS_KEY = access_key + qianfan.get_config().SECRET_KEY = secret_key + qianfan.get_config().BEARER_TOKEN = None + + +def test_refresh_token(): + preset_interval = qianfan.get_config().BEARER_TOKEN_EXPIRED_INTERVAL + qianfan.get_config().BEARER_TOKEN_EXPIRED_INTERVAL = 5 + chat = qianfan.ChatCompletion(version=2, app_id="app-xxx") + + def call() -> qianfan.QfResponse: + resp = chat.do( + messages=[{"role": "user", "content": "你好"}], + model="xxxx", + preemptable=True, + top_p=0.5, + ) + return resp + + resp1 = call() + resp2 = call() + time.sleep(6) + assert ( + resp1.request.headers["Authorization"] == resp2.request.headers["Authorization"] + ) + qianfan.get_config().BEARER_TOKEN_EXPIRED_INTERVAL = preset_interval diff --git a/python/qianfan/tests/utils/mock_server.py b/python/qianfan/tests/utils/mock_server.py index 28ad61f0..33288910 100644 --- a/python/qianfan/tests/utils/mock_server.py +++ b/python/qianfan/tests/utils/mock_server.py @@ -24,6 +24,7 @@ import threading import time import zipfile +from datetime import datetime, timedelta, timezone from functools import wraps from io import BytesIO @@ -268,6 +269,31 @@ def wrapper(*args, **kwargs): return wrapper +def iam_v3_auth_checker(func): + """ + decorator for checking bearer token + """ + + @wraps(func) + def wrapper(*args, **kwargs): + """ + wrapper for function + """ + authorization = request.headers.get("authorization") + if not authorization.startswith("Bearer"): + return flask.Response( + status=403, + headers={ + "X-Bce-Error-Message": ( + "mock server error, authorization or bce_date not found" + ) + }, + ) + return func(*args, **kwargs) + + return wrapper + + 
def iam_auth_checker(func): """ decorator for checking access token @@ -622,8 +648,36 @@ def chat(model_name): ) +history_tokens = {} + + +@app.route(Consts.IAMBearerTokenAPI, methods=["GET"]) +def iam_get_bearer_token(): + expire_seconds = request.args.get("expireInSeconds") + current_time = datetime.now(timezone.utc) + # 加上 100 秒 + expire_time = current_time + timedelta(seconds=int(expire_seconds)) + + # 格式化为指定格式 + current_time_str = current_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + expire_time_str = expire_time.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] + "Z" + token = f"bce-v3/ALTAK-XNYpvSpWTC0Qr0cB6LoZR/{generate_letter_num_random_id(16)}" + global history_tokens + history_tokens[token] = expire_time + + return json_response( + { + "userId": "6c6093c96f0241c087af184cc5729de8", + "token": token, + "status": "enable", + "createTime": current_time_str, + "expireTime": expire_time_str, + }, + ) + + @app.route(Consts.ChatV2API, methods=["POST"]) -@iam_auth_checker +@iam_v3_auth_checker def chat_v2(): """ mock chat completion v2 api @@ -634,7 +688,6 @@ def chat_v2(): model_name = r["model"] if model_name.startswith("test_retry"): global retry_cnt - print("mock retry cnt", retry_cnt) if model_name not in retry_cnt: retry_cnt[model_name] = 1 if retry_cnt[model_name] % 3 != 0: diff --git a/python/qianfan/tests/utils/utils.py b/python/qianfan/tests/utils/utils.py index d100b0b6..c503b576 100644 --- a/python/qianfan/tests/utils/utils.py +++ b/python/qianfan/tests/utils/utils.py @@ -30,6 +30,7 @@ def init_test_env(): os.environ[Env.BaseURL] = "http://127.0.0.1:8866" os.environ[Env.ConsoleAPIBaseURL] = "http://127.0.0.1:8866" os.environ[Env.DisableErnieBotSDK] = "True" + os.environ[Env.IAMBaseURL] = "http://127.0.0.1:8866" qianfan.enable_log(logging.INFO) if "QIANFAN_AK" in os.environ: os.environ.pop("QIANFAN_AK") diff --git a/python/qianfan/utils/fake_pyarrow/__init__.py b/python/qianfan/utils/fake_pyarrow/__init__.py index f7f74256..6b42fd9d 100644 --- 
a/python/qianfan/utils/fake_pyarrow/__init__.py +++ b/python/qianfan/utils/fake_pyarrow/__init__.py @@ -17,9 +17,15 @@ from qianfan.utils.fake_pyarrow.functions import concat_tables from qianfan.utils.fake_pyarrow.table import ChunkedArray, Table +from qianfan.utils.logging import log_warn is_fake = True +log_warn( + "no pyarrow has been installed. Dataset will run in restrict mode in which some" + " functions may not be available" +) + __all__ = [ "Table", "concat_tables",