diff --git a/cookbook/RAG/baidu_elasticsearch/qianfan_baidu_elasticsearch.ipynb b/cookbook/RAG/baidu_elasticsearch/qianfan_baidu_elasticsearch.ipynb index af04f604..6536ec4f 100644 --- a/cookbook/RAG/baidu_elasticsearch/qianfan_baidu_elasticsearch.ipynb +++ b/cookbook/RAG/baidu_elasticsearch/qianfan_baidu_elasticsearch.ipynb @@ -12,6 +12,21 @@ "本文主要介绍基于Langchain的框架,结合BES的向量数据库的能力,对接千帆平台的模型管理和应用接入的能力,从而构建一个RAG的知识问答场景。" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "bes,1\n", + "faiss,1\n", + "chroma, 1\n", + "\n", + "postgresql\n", + "milvus, \n", + "pincone2" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -73,7 +88,7 @@ "\n", "loader = TextLoader(\"example_data/ai-paper.pdf\")\n", "documents = loader.load()\n", - "text_splitter = CharacterTextSplitter(chunk_size=768, chunk_overlap=0, separators=[\"\\n\\n\", \"\\n\", \" \", \"\", \"。\", \",\"])\n", + "text_splitter = CharacterTextSplitter(chunk_size=768, chunk_overlap=0, separators=[\"\\n\\n\", \"\\n\", \" \", \"\", \"。\", \",\"]) # spaciy\n", "docs = text_splitter.split_documents(documents)" ] }, @@ -92,9 +107,13 @@ "metadata": {}, "outputs": [], "source": [ - "from langchain.embeddings import QianfanEmbeddingsEndpoint\n", + "from langchain.embeddings import QianfanEmbeddingsEndpoint #sdk\n", "\n", - "embeddings = QianfanEmbeddingsEndpoint()" + "embeddings = QianfanEmbeddingsEndpoint()\n", + "# embeddings-v1\n", + "# bge-large-zh\n", + "# 12月2k\n", + "# " ] }, { @@ -123,6 +142,15 @@ "bes.client.indices.refresh(index=\"your vector index\")" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Faiss\n" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -141,6 +169,7 @@ "from langchain.chat_models import QianfanChatEndpoint\n", "\n", "qianfan_chat_model = QianfanChatEndpoint(model=\"ERNIE-Bot\")\n", + "# sdk prompt load from qianfan\n", "qa = RetrievalQA.from_chain_type(llm=llm, chain_type=\"refine\", retriever=retriever, return_source_documents=True)\n", "\n", "\n", @@ -171,18 +200,18 @@ ], "metadata": { "kernelspec": { - "display_name": "Python 3", + "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "name": "python", - "version": "3.9.17" + "version": "3.11.5" }, "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "aee8b7b246df8f9039afb4144a1f6fd8d2ca17a180786b69acc140d282b71a49" + "hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb" } } }, diff --git a/cookbook/eb_search.ipynb b/cookbook/eb_search.ipynb index 0842c665..768d3609 100644 --- a/cookbook/eb_search.ipynb +++ b/cookbook/eb_search.ipynb @@ -204,7 +204,7 @@ ], "metadata": { "kernelspec": { - "display_name": "py311", + "display_name": "base", "language": "python", "name": "python3" }, @@ -223,7 +223,7 @@ "orig_nbformat": 4, "vscode": { "interpreter": { - "hash": "f553a591cb5da27fa30e85168a93942a1a24c8d6748197473adb125e5473a5db" + "hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb" } } }, diff --git a/cookbook/finetune/trainer-finetune.ipynb b/cookbook/finetune/trainer-finetune.ipynb new file mode 100644 index 00000000..39d36acb --- /dev/null +++ b/cookbook/finetune/trainer-finetune.ipynb @@ -0,0 +1,268 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Trainer\n", + "千帆Python SDK 在使用[resource API实现发起训练微调](./console-finetune.ipynb)之外,还提供了Trainer API,可以更方便地实现一体化的训练微调pipeline。同时提供了状态事件回调函数的注册,通过事件分发实现训练流程状态事件的监控。\n", + "\n", + "\n", + "本例将基于qianfan==0.2.1展示通过Dataset加载本地数据集,并上传到千帆平台,基于LLama-2-7b进行fine-tune,直到最终完成服务发布,并最终实现服务调用的完整过程。" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 前置准备\n", + "- 初始化千帆安全认证AK、SK" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import os \n", + "\n", + "os.environ[\"QIANFAN_ACCESS_KEY\"] = \"your_ak\"\n", + "os.environ[\"QIANFAN_SECRET_KEY\"] = \"your_sk\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## 数据集加载\n", + "\n", + "千帆SDK提供了数据集实现帮助我们可以快速的加载本地的数据集到内存,并通过设定DataSource数据源以保存至本地和千帆平台。" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[[{'prompt': '请根据下面的新闻生成摘要, 内容如下:新华社受权于18日全文播发修改后的《中华人民共和国立法法》,修改后的立法法分为“总则”“法律”“行政法规”“地方性法规、自治条例和单行条例、规章”“适用与备案审查”“附则”等6章,共计105条。\\n生成摘要如下:',\n", + " 'response': [['修改后的立法法全文公布']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:一辆小轿车,一名女司机,竟造成9死24伤。日前,深圳市交警局对事故进行通报:从目前证据看,事故系司机超速行驶且操作不当导致。目前24名伤员已有6名治愈出院,其余正接受治疗,预计事故赔偿费或超一千万元。\\n生成摘要如下:',\n", + " 'response': [['深圳机场9死24伤续:司机全责赔偿或超千万']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:1月18日,习近平总书记对政法工作作出重要指示:2014年,政法战线各项工作特别是改革工作取得新成效。新形势下,希望全国政法机关主动适应新形势,为公正司法和提高执法司法公信力提供有力制度保障。\\n生成摘要如下:',\n", + " 'response': [['孟建柱:主动适应形势新变化提高政法机关服务大局的能力']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:针对央视3·15晚会曝光的电信行业乱象,工信部在公告中表示,将严查央视3·15晚会曝光通信违规违法行为。工信部称,已约谈三大运营商有关负责人,并连夜责成三大运营商和所在省通信管理局进行调查,依法依规严肃处理。\\n生成摘要如下:',\n", + " 'response': [['工信部约谈三大运营商严查通信违规']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:国家食药监管总局近日发布《食品召回管理办法》,明确:食用后已经或可能导致严重健康损害甚至死亡的,属一级召回,食品生产者应在知悉食品安全风险后24小时内启动召回,且自公告发布之日起10个工作日内完成召回。\\n生成摘要如下:',\n", + " 'response': [['食品一级召回限24小时内启动10工作日完成']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:人民检察院刑事诉讼涉案财物管理规定明确,不得查封、扣押、冻结与案件无关的财物,严禁在立案前查封、扣押、冻结财物,对查明确实与案件无关的,应当在三日内予以解除、退还。\\n生成摘要如下:',\n", + " 'response': [['最高检:诉讼未终结涉案财物不得上缴国库']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:聂树斌案的复查有了新进展,山东高院已正式通知聂树斌案申诉代理律师阅卷。这也是该案律师10年来首次获准阅卷。此前山东省高院复查聂树斌案合议庭成员提讯了王书金,其仍坚称石家庄西郊玉米地强奸杀人案是他所为。\\n生成摘要如下:',\n", + " 'response': [['聂树斌案律师10年来首获准阅卷']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:近日北京市司法局、北京市检察院联合召开人民监督员选任管理方式改革工作会议,宣布了北京市第一届113名人民监督员的任命决定,并组织监督员向宪法集体宣誓。人民监督员以“第三方”去监督检察院办案,机制上做到了相互制衡。\\n生成摘要如下:',\n", + " 'response': [['北京市第一届人民监督员向宪法集体宣誓']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:据315晚会报道,公共免费WIFI存在隐患。黑客可利用轻易盗取用户个人信息,如账号、密码等。为了保证您个人信息安全,在公共场所尽量不要使用那些不需要密码免费wifi。\\n生成摘要如下:',\n", + " 'response': [['免费公共wifi存隐患黑客可轻易获取用户信息']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:3月12日,最高人民法院院长周强作最高人民法院工作报告。周强表示,去年各级法院再审改判刑事案件1317件,其中纠正一批重大冤假错案。对错案的发生,我们深感自责,要求各级法院深刻汲取教训……更多工作报告要点详见↓\\n生成摘要如下:',\n", + " 'response': [['一张图看懂最高法2014年工作报告']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:如今,用手机软件叫“代驾”是件时尚而方便的事情,可是一旦发生交通事故,车主、代驾人、手机代驾软件及保险公司,究竟由谁来付“代价”?上海浦东新区法院9日宣判一起新类型案件,软件运行公司被判担责。\\n生成摘要如下:',\n", + " 'response': [['全国首例涉代驾软件交通事故案宣判']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:中央全面深化改革领导小组第十次会议审议通过了《深化人民监督员制度改革方案》,提出由司法行政机关负责人民监督员的选任和培训等管理工作,解决了检察机关“自己选人监督自己”的问题,使人民监督员的监督更有说服力....>>\\n生成摘要如下:',\n", + " 'response': [['全面深化人民监督员制度改革']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:(民主与法制网)“老赖”,就是欠钱不还的人。老赖如同阳光下的隐形人,人人喊打但无迹可寻,数量不多甚至寥寥可数,却是让社会头痛的“老大难”。曾有一位法院工作人员对记者无奈地表示:“这些老赖一而再再而三地推脱赖账,很多案件一时难以执行下去。\\n生成摘要如下:',\n", + " 'response': [['多部门联合捉“老赖”']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:国家税务总局透露,去年税务协同公安等部门共查处发票违法犯罪案件10.2万起,抓获犯罪嫌疑人6014人,缴获假发票6449万份。其中税务部门查处违法企业9.9万户,查补税款106亿元,加收滞纳金10亿元,罚款和没收违法所得18亿元。\\n生成摘要如下:',\n", + " 'response': [['去年全国查处发票违法犯罪案件10.2万起']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:环保部决定对存在各类问题的63家建设项目环评机构和22名环评工程师分别作出了取消资质、缩减评价范围、限期整改、通报批评等相应处理。要求所有具有环保部门背景的环评单位全面退出建设项目环评技术服务市场,彻底解决“红顶中介”问题。\\n生成摘要如下:',\n", + " 'response': [['环保部要求有背景单位退出环评市场']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:最高法5日公布《关于民事审判监督程序严格依法适用指令再审和发回重审若干问题的规定》,规范民事案件指令再审和发回重审的标准,确保再审程序充分发挥依法纠错功能。该司法解释自2015年3月15日起施行\\n生成摘要如下:',\n", + " 'response': [['最新司法解释:确保民事案件再审程序充分发挥纠错功能']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:4日,最高法、最高检、公安部和司法部联合发布《关于依法办理家庭暴力犯罪案件的意见》,意见明确,对正在进行的家庭暴力采取制止行为,只要符合刑法规定的条件,就应当依法认定为正当防卫,不负刑事责任。\\n生成摘要如下:',\n", + " 'response': [['四部门明确认定家暴正当防卫:免刑罚或轻判']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:对于消费者网购后悔权,即“七天内无理由可退货”,商家常以“已拆封”等为由拒绝执行。3月15日起,国家工商总局明确规定,“已拆封”不得作为拒退货理由,故意拒绝或拖延退货商家最高将受到50万元处罚。\\n生成摘要如下:',\n", + " 'response': [['工商总局:“已拆封”不能作为拒退货理由']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:从选种、育苗、种植、田间管理、加工方法、到冲泡手法等,福山咖啡有一套非常完备的专业技术,福山咖啡品质与口味始终如一。在2015年海报集团年货展上,澄迈福山咖啡联合公司将带着浓郁的福山咖啡给前来展会的市民品尝。\\n生成摘要如下:',\n", + " 'response': [['澄迈福山咖啡口感纯正,想尝尝吗']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:最近,中国载人航天网发布了长征7号火箭矗立在文昌新型移动发射平台上的照片。据悉,长征7号火箭可能在2016年进行首次发射,身处海南亲们到时候记得到文昌围观火箭发射哦~\\n生成摘要如下:',\n", + " 'response': [['长征7号来了可以在文昌围观火箭发射咯']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:周末采摘杨桃助农,游玩澄迈养生农庄自驾活动,快来报名~1月31日早上8点半出发哦,前往澄迈杨桃园采摘,入园免费,漫山遍野的杨桃和你亲密接触(带走的按2.5元/斤)。报名方式:拨136-3755-3497报名,或戳链接\\n生成摘要如下:',\n", + " 'response': [['这个周末自驾游去摘杨桃啦!']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:海南省质量技术监督局近日通报了《2014年烟花爆竹产品质量省级监督抽查结果》,海南有2家企业4种产品烟花爆竹不合格,防安、好日子商标名列其中。戳链接查看具体贵规格型号大家要认准了,别买到不合格的哦\\n生成摘要如下:',\n", + " 'response': [['好日子是“车大炮”?过年这些鞭炮你们别买']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:海口秀英长流镇墟上的杨小姐寄包裹却遭快递员调戏:让我亲一口就不收你钱。寄个快递,没想到却被快递员出言调戏。丈夫上门讨说法,却还快递员带一伙人手持长棍大刀石头砸了自己的车。\\n生成摘要如下:',\n", + " 'response': [['海口快递小哥:让我亲一口就不收你钱!']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:歌手尹相杰涉嫌非法持有毒品案于今日上午9时在北京市朝阳区人民法院依法公开开庭审理。对公诉机关指控的事实和罪名,尹相杰均无异议并自愿认罪。法官当庭宣布,尹相杰因非法持有毒品罪,一审获刑7个月并处罚金2000元。\\n生成摘要如下:',\n", + " 'response': [['歌手尹相杰非法持有毒品一审获刑7个月']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:在回答有关工商局和淘宝争议问题时,国家工商总局副局长刘玉亭表示,互联网企业要加强与政府有关部门的沟通,积极反映发展中遇到的问题并提出积极有建设性的意见和建议,政府有关部门要加强对企业的指导。\\n生成摘要如下:',\n", + " 'response': [['工商总局回应与淘宝争议:政府要加强对企业的指导']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:今天湖南沅江市检察院对陆勇涉嫌“妨害信用卡管理“和”销售假药”案做出最终决定,认为其行为不构成犯罪,决定不起诉。陆勇因给千余网友分享购买仿制格列卫的印度抗癌药渠道被称“抗癌药代购第一人”,后被检方起诉。\\n生成摘要如下:',\n", + " 'response': [['检方决定不起诉“抗癌药代购第一人”']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:《互联网用户账号名称管理规定》将于3月1日正式施行。昨日,国家互联网信息办公室有关负责人表示,多家互联网企业近日已处置微博客、博客、论坛、贴吧和即时通信工具等各类违法违规账号6万余个。\\n生成摘要如下:',\n", + " 'response': [['网信办:“账号十条”将施行多网站自查6万个违规账号']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:下周最高法第一巡回法庭将开庭审理第一宗案件,最高法巡回法庭历史上的首宗案件,有怎样的特别之处?巡回法庭将首次实行主审法官、合议庭办案负责制,这意味着第一宗案件主审法官将敲响这一制度历史上的“第一槌”\\n生成摘要如下:',\n", + " 'response': [['最高法第一巡回法庭将开审首案:跨省买卖合同纠纷']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:针对政府采购活动中也暴露出质量不高、效率低下等问题,特制定该条例,完善政府采购制度,进一步促进政府采购的规范化、法制化,构建规范透明、公平竞争、监督到位、严格问责的政府采购工作机制。\\n生成摘要如下:',\n", + " 'response': [['政府采购法实施条例3月1日施行']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:公安部会同国家网信办、工信部、环保部、工商总局、安监总局制定出台《互联网危险物品信息发布管理规定》,进一步加强对互联网危险物品信息的管理,规范危险物品从业单位信息发布行为。规定自2015年3月1日起执行。\\n生成摘要如下:',\n", + " 'response': [['新规:禁止个人在互联网上发布危险物品信息']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:银行卡含有大量信息,例如公民个人身份信息、账户信息等。能否收藏他人银行卡,收藏都有哪些法律风险以及容许自己的卡进入收藏流通领域对于持卡人会不会具有不利法律后果等问题?戳长微博了解↓↓↓\\n生成摘要如下:',\n", + " 'response': [['收藏银行卡隐藏多重法律风险5张以上就有可能犯罪']]}],\n", + " [{'prompt': '请根据下面的新闻生成摘要, 内容如下:说起熊孩子,就想到“牛爸爸”。1月23日傍晚,因女儿要结婚,在公园保安部当部长的老爸就可以开方便之门,让客人把车停到公园里面。八项规定都这么久了,这位保安部长还敢这么干,可真是位“牛爸爸”。\\n生成摘要如下:',\n", + " 'response': [['“熊孩子”VS“牛爸爸”']]}]]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from qianfan.dataset import Dataset\n", + "from qianfan.trainer import LLMFinetune\n", + "\n", + "# 加载本地数据集\n", + "ds: Dataset = Dataset.load(data_file=\"./news_digest.jsonl\")\n", + "ds.list()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "True" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "# 保存到千帆平台\n", + "from qianfan.dataset import QianfanDataSource\n", + "from qianfan.resources.console import consts as console_consts\n", + "\n", + "bos_bucket_name = \"bos_bucket_name\"\n", + "bos_bucket_file_path = \"/data_file_path/\"\n", + "\n", + "\n", + "# 创建千帆数据集,并上传保存\n", + "qianfan_data_source = QianfanDataSource.create_bare_dataset(\n", + " name=\"sdk_trainer_ds_04\",\n", + " template_type=console_consts.DataTemplateType.NonSortedConversation,\n", + " storage_type=console_consts.DataStorageType.PrivateBos,\n", + " storage_id=bos_bucket_name,\n", + " storage_path=bos_bucket_file_path,\n", + ")\n", + "\n", + "ds.save(qianfan_data_source, replace_source=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "receive: < {\"action_id\": \"Pipeline_Ppbl7XMXuu\", \"action_state\": \"Preceding\", \"description\": \"action_event: action[Ppbl7XMXuu], msg:\", \"data\": {}}\n", + "receive: < {\"action_id\": \"Pipeline_Ppbl7XMXuu\", \"action_state\": \"Running\", \"description\": \"action_event: action[Ppbl7XMXuu], msg:pipeline running\", \"data\": {\"action\": \"FdG0AG6Gng\"}}\n", + "receive: < {\"action_id\": \"LoadDataSetAction_FdG0AG6Gng\", \"action_state\": \"Preceding\", \"description\": \"action_event: action[FdG0AG6Gng], msg:\", \"data\": {}}\n", + "receive: < {\"action_id\": \"LoadDataSetAction_FdG0AG6Gng\", \"action_state\": \"Done\", \"description\": \"action_event: action[FdG0AG6Gng], msg:\", \"data\": {\"datasets\": [{\"id\": 38039, \"type\": 1}]}}\n", + "receive: < {\"action_id\": \"Pipeline_Ppbl7XMXuu\", \"action_state\": \"Running\", \"description\": \"action_event: action[Ppbl7XMXuu], msg:pipeline running\", \"data\": {\"action\": \"sXRN83BjSL\"}}\n", + "receive: < {\"action_id\": \"TrainAction_sXRN83BjSL\", \"action_state\": \"Preceding\", \"description\": \"action_event: action[sXRN83BjSL], msg:\", \"data\": {}}\n" + ] + } + ], + "source": [ + "# 对于需要在训练过程中监控每个阶段的各个节点的用户,可以通过事件回调函数来实现\n", + "from qianfan.trainer.event import Event, EventHandler\n", + "from qianfan.trainer.consts import ActionState\n", + "from qianfan.resources.console import consts as console_consts\n", + "from qianfan.trainer.configs import TrainConfig\n", + "from qianfan.trainer.base import Pipeline\n", + "from qianfan.trainer.model import Service, DeployConfig\n", + "from qianfan.resources import QfMessages\n", + "from typing import cast\n", + "\n", + "\n", + "testset: Dataset = Dataset.load(data_file=\"./news_digest_test.jsonl\")\n", + "# 定义自己的EventHandler,并实现dispatch方法\n", + "class InferAfterSFT(EventHandler):\n", + " target_action: str\n", + " def __init__(self, target_action: str) -> None:\n", + " super().__init__()\n", + " self.target_action = target_action\n", + "\n", + " def dispatch(self, event: Event) -> None:\n", + " print(\"receive: <\", event)\n", + " if self.target_action == event.action_id and event.action_state == ActionState.Done:\n", + " svc = cast(Service, event.data[\"service\"])\n", + " print(\"svc\", svc)\n", + " for row in testset.list():\n", + " msgs = QfMessages()\n", + " msgs.append(row[0][0][\"prompt\"], \"user\")\n", + " svc.exec({\"messages\":\"msgs\"})\n", + " print(\"row infer result\", row)\n", + " \n", + "trainer = LLMFinetune(\n", + " train_type=\"ERNIE-Bot-turbo-0516\",\n", + " dataset=ds,\n", + " train_config=TrainConfig(\n", + " epoch=1,\n", + " learning_rate=0.00003,\n", + " max_seq_len=4096,\n", + " peft_type=\"LoRA\",\n", + " ),\n", + " deploy_config=DeployConfig(\n", + " name=\"fin_eb_04\",\n", + " replicas=1,\n", + " pool_type=console_consts.DeployPoolType.PrivateResource,\n", + " ),\n", + ")\n", + "eh = InferAfterSFT(target_action=trainer.ppls[0].id)\n", + "trainer.register_event_handler(eh)\n", + "trainer.run()\n", + "\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.5" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "58f7cb64c3a06383b7f18d2a11305edccbad427293a2b4afa7abe8bfc810d4bb" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/src/qianfan/dataset/data_source.py b/src/qianfan/dataset/data_source.py index 2b3f1507..6b5083f4 100644 --- a/src/qianfan/dataset/data_source.py +++ b/src/qianfan/dataset/data_source.py @@ -1009,6 +1009,10 @@ def release_dataset(self) -> bool: Returns: bool: Whether releasing succeeded """ + info = Data.get_dataset_info(self.id)["result"]["versionInfo"] + status = info["releaseStatus"] + if status == DataReleaseStatus.Finished: + return True Data.release_dataset(self.id) while True: sleep(get_config().RELEASE_STATUS_POLLING_INTERVAL) diff --git a/src/qianfan/dataset/dataset.py b/src/qianfan/dataset/dataset.py index 4d1fffd9..b5b3405f 100644 --- a/src/qianfan/dataset/dataset.py +++ b/src/qianfan/dataset/dataset.py @@ -358,6 +358,7 @@ def save( qianfan_dataset_id: Optional[int] = None, qianfan_dataset_create_args: Optional[Dict[str, Any]] = None, schema: Optional[Schema] = None, + replace_source: bool = False, **kwargs: Any, ) -> bool: """ @@ -379,6 +380,8 @@ def save( default to None schema: (Optional[Schema]): schema used to validate before exporting data, default to None + replace_source: (bool): + if replace the original source, default to False kwargs (Any): optional arguments Returns: @@ -416,7 +419,10 @@ def save( kwargs["is_annotated"] = schema.is_annotated # 开始写入数据 - return self._to_source(source, **kwargs) # noqa + res = self._to_source(source, **kwargs) # noqa + if res and replace_source: + self.inner_data_source_cache = source + return res @classmethod def create_from_pyobj( diff --git a/src/qianfan/tests/trainer_test.py b/src/qianfan/tests/trainer_test.py index bdb38eda..30d8aa5e 100644 --- a/src/qianfan/tests/trainer_test.py +++ b/src/qianfan/tests/trainer_test.py @@ -23,7 +23,7 @@ from qianfan.trainer.consts import ServiceType from qianfan.trainer.event import Event, EventHandler from qianfan.trainer.finetune import LLMFinetune -from qianfan.trainer.model import Model +from qianfan.trainer.model import Model, Service class MyEventHandler(EventHandler): @@ -148,3 +148,10 @@ def test_model_deploy(): resp = svc.exec({"messages": [{"content": "hi", "role": "user"}]}) assert resp["result"] != "" + + +def test_service(): + svc = Service(model="ERNIE-Bot", service_type=ServiceType.Chat) + resp = svc.exec({"messages": [{"content": "hi", "role": "user"}]}) + assert resp is not None + assert resp["result"] != "" diff --git a/src/qianfan/trainer/actions.py b/src/qianfan/trainer/actions.py index 27d8635a..19da342d 100644 --- a/src/qianfan/trainer/actions.py +++ b/src/qianfan/trainer/actions.py @@ -21,7 +21,6 @@ from qianfan.resources.console import consts as console_consts from qianfan.trainer.base import ( BaseAction, - EventHandler, with_event, ) from qianfan.trainer.configs import DefaultTrainConfigMapping, DeployConfig, TrainConfig @@ -55,10 +54,9 @@ class LoadDataSetAction(BaseAction[Dict[str, Any], Dict[str, Any]]): def __init__( self, dataset: Optional[Dataset] = None, - event_handler: Optional[EventHandler] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: - super().__init__(event_handler=event_handler) + super().__init__(**kwargs) self.dataset = dataset @with_event @@ -73,6 +71,7 @@ def exec(self, input: Dict[str, Any] = {}, **kwargs: Dict) -> Dict[str, Any]: ) log_debug("[load_dataset_action] prepare train-set") qf_data_src = cast(QianfanDataSource, self.dataset.inner_data_source_cache) + print("==>") is_released = qf_data_src.release_dataset() if not is_released: raise InvalidArgumentError("dataset must be released") @@ -365,16 +364,14 @@ class DeployAction(BaseAction[Dict[str, Any], Dict[str, Any]]): model_id: Optional[int] model_version_id: Optional[int] - def __init__( - self, deploy_config: Optional[DeployConfig] = None, **kwargs: Dict[str, Any] - ): + def __init__(self, deploy_config: Optional[DeployConfig] = None, **kwargs: Any): """ Parameters: deploy_config (Optional[DeployConfig], optional): deploy config include replicas and so on. Defaults to None. """ - super().__init__(kwargs=kwargs) + super().__init__(**kwargs) self.deploy_config = deploy_config @with_event diff --git a/src/qianfan/trainer/base.py b/src/qianfan/trainer/base.py index 7566ccb6..82cbb7d1 100644 --- a/src/qianfan/trainer/base.py +++ b/src/qianfan/trainer/base.py @@ -100,8 +100,8 @@ def __init__( event_handler (Optional[EventHandler], optional): event_handler implements for action state track. Defaults to None. """ - self.id = id if id is not None else utils.uuid() - self.name = name if name is not None else f"actions_{self.id}" + self.id = id if id is not None else utils.generate_letter_num_random_id() + self.name = name if name is not None else f"action_{self.id}" self.state = ActionState.Preceding self.event_dispatcher = event_handler @@ -192,6 +192,10 @@ def action_event(self, state: ActionState, msg: str = "", data: Any = None) -> N ), ) + @classmethod + def action_type(cls) -> str: + return "base" + def with_event(func: Callable[..., Any]) -> Callable[..., Any]: """ @@ -329,6 +333,22 @@ def stop(self, **kwargs: Dict) -> None: return super().stop() + def register_event_handler( + self, event_handler: EventHandler, action_id: Optional[str] = None + ) -> None: + """ + Register the event handler to specific the action. + Args: + event_handler (EventHandler): The event handler instance. + """ + self.event_dispatcher = event_handler + for id, action in self.actions.items(): + if action_id is None and id == action_id: + action.event_dispatcher == event_handler + break + else: + action.event_dispatcher = event_handler + class Trainer(ABC): """ @@ -397,3 +417,18 @@ def get_log(self) -> Any: Receive the training log during the pipeline execution. [coming soon]. """ raise NotImplementedError("trainer get_log") + + def register_event_handler( + self, event_handler: EventHandler, ppl_id: Optional[str] = None + ) -> None: + """ + Register the event handler to specific the ppls. + Args: + event_handler (EventHandler): The event handler instance. + """ + for ppl in self.ppls: + if ppl_id is None and ppl.id == ppl_id: + ppl.register_event_handler(event_handler) + break + else: + ppl.register_event_handler(event_handler) diff --git a/src/qianfan/trainer/consts.py b/src/qianfan/trainer/consts.py index 1a95d497..8501fe0a 100644 --- a/src/qianfan/trainer/consts.py +++ b/src/qianfan/trainer/consts.py @@ -12,6 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. from enum import Enum +from typing import Any, Dict + +from qianfan.resources import ChatCompletion, Completion, Embedding, Text2Image class ActionState(str, Enum): @@ -36,7 +39,15 @@ class ActionState(str, Enum): class FinetuneStatus(str, Enum): Unknown = "Unknown" """未知状态""" - Created = "Created" + DatasetLoading = "DatasetLoading" + """数据集加载中""" + DatasetLoaded = "DatasetLoaded" + """数据集加载完成""" + DatasetLoadFailed = "DatasetLoadFailed" + """数据集加载失败""" + DatasetLoadStopped = "DatasetLoadStopped" + """数据集停止加载""" + TrainCreated = "TrainCreated" """任务创建,初始化""" Training = "Training" """训练中 对应训练任务运行时API状态的Running""" @@ -92,3 +103,12 @@ class ServiceType(str, Enum): """Corresponding to the `Embedding`""" Text2Image = "Text2Image" """Corresponding to the `Text2Image""" + + +# service type -> resources class +ServiceTypeResourcesMapping: Dict[ServiceType, Any] = { + ServiceType.Chat: ChatCompletion, + ServiceType.Completion: Completion, + ServiceType.Embedding: Embedding, + ServiceType.Text2Image: Text2Image, +} diff --git a/src/qianfan/trainer/finetune.py b/src/qianfan/trainer/finetune.py index f9e0c800..28c4cda9 100644 --- a/src/qianfan/trainer/finetune.py +++ b/src/qianfan/trainer/finetune.py @@ -121,7 +121,7 @@ def __init__( event_handler=event_handler, **kwargs, ) - self.model_publish = ModelPublishAction() + self.model_publish = ModelPublishAction(event_handler=event_handler) actions = [ self.load_data_action, @@ -131,7 +131,7 @@ def __init__( if deploy_config is not None: self.deploy_action = DeployAction( deploy_config=deploy_config, - **{"event_handler": event_handler, **kwargs}, + event_handler=event_handler, ) actions.append(self.deploy_action) @@ -208,8 +208,15 @@ def resume(self, **kwargs: Dict) -> "LLMFinetune": # mapping for action state -> fine-tune status fine_tune_action_mapping: Dict[str, Dict[str, Any]] = { + LoadDataSetAction.__class__.__name__: { + ActionState.Preceding: FinetuneStatus.DatasetLoading, + ActionState.Running: FinetuneStatus.DatasetLoading, + ActionState.Done: FinetuneStatus.DatasetLoaded, + ActionState.Error: FinetuneStatus.DatasetLoadFailed, + ActionState.Stopped: FinetuneStatus.DatasetLoadStopped, + }, TrainAction.__class__.__name__: { - ActionState.Preceding: FinetuneStatus.Created, + ActionState.Preceding: FinetuneStatus.TrainCreated, ActionState.Running: FinetuneStatus.Training, ActionState.Done: FinetuneStatus.TrainFinished, ActionState.Error: FinetuneStatus.TrainFailed, diff --git a/src/qianfan/trainer/model.py b/src/qianfan/trainer/model.py index 75ffa642..89c019d3 100644 --- a/src/qianfan/trainer/model.py +++ b/src/qianfan/trainer/model.py @@ -29,7 +29,7 @@ from qianfan.trainer.base import ExecuteSerializable from qianfan.trainer.configs import DeployConfig from qianfan.trainer.consts import ServiceType -from qianfan.utils import log_warn +from qianfan.utils import log_info, log_warn class Model( @@ -39,7 +39,7 @@ class Model( """remote model id""" version_id: Optional[int] """remote model version id""" - name: str + name: Optional[str] = None """model name""" service: Optional["Service"] = None """model service""" @@ -54,6 +54,7 @@ def __init__( version_id: Optional[int] = None, task_id: Optional[int] = None, job_id: Optional[int] = None, + name: Optional[str] = None, ): """ Class for model in qianfan, which is deployable by using deploy() to @@ -73,6 +74,7 @@ def __init__( self.version_id = version_id self.task_id = task_id self.job_id = job_id + self.name = name def exec( self, input: Optional[Dict] = None, **kwargs: Dict @@ -113,8 +115,9 @@ def deploy(self, deploy_config: DeployConfig) -> "Service": Returns: Service: model service instance """ - self.service = model_deploy(self, deploy_config) - + if self.service is None: + self.service = model_deploy(self, deploy_config) + log_info("model service already existed") return self.service def publish(self, name: str = "") -> "Model": @@ -254,7 +257,7 @@ def __init__( self, id: Optional[int] = None, endpoint: Optional[str] = None, - model: Optional[Model] = None, + model: Optional[Union[Model, str]] = None, deploy_config: Optional[DeployConfig] = None, service_type: Optional[ServiceType] = None, ) -> None: @@ -276,15 +279,27 @@ def __init__( Defaults to None. """ self.id = id - self.endpoint = endpoint - self.model = model + self.service_type = service_type + if self.service_type is None: + log_warn("service type should be specified before exec") + if endpoint is not None: + self.model = None + self.endpoint = endpoint + elif isinstance(model, str): + self.model = Model(name=model) + self.endpoint = None + elif isinstance(model, Model): + # need to deploy + self.model = model + self.endpoint = None + else: + raise InvalidArgumentError("invalid model service") self.deploy_config = deploy_config self.service_type = service_type - if self.endpoint is not None and self.service_type is None: - log_warn("service type should be specified when endpoint passed in") + # if self.endpoint is not None and self.service_type is None: @property - def status(self) -> console_const.ServiceStatus: + def status(self) -> str: """ get the service status @@ -295,8 +310,9 @@ def status(self) -> console_const.ServiceStatus: console_const.ServiceStatus """ if self.id is None: - raise InternalError("service id not found") - resp = api.Service.get(id=self.id) + return "" + else: + resp = api.Service.get(id=self.id) return resp["result"]["serviceStatus"] def exec( @@ -318,23 +334,87 @@ def exec( """ if input is None: raise InvalidArgumentError("input is none") + return self.get_res().do(**{**input, **kwargs}) + + def get_res(self) -> Union[ChatCompletion, Completion, Embedding, Text2Image]: + """ + convert to the specific model resources. e.g. + `ChatCompletion`, `Completion`, `Embeddings`, + `Text2Image` + + Returns: + Union[ChatCompletion, Completion, Embedding, Text2Image]: + resource object + """ if self.endpoint is not None and self.service_type is None: raise InvalidArgumentError( "service type must be specified when endpoint passed in" ) - if self.status != console_const.ServiceStatus.Done: - raise InternalError("service is not ready") + svc_status = self.status + if svc_status != console_const.ServiceStatus.Done: + log_warn("service status unknown, service could be unavailable.") if self.service_type == ServiceType.Chat: - return ChatCompletion().do(endpoint=self.endpoint, **input) + return ChatCompletion( + model=(self.model.name if self.model is not None else None), + endpoint=self.endpoint, + ) elif self.service_type == ServiceType.Completion: - return Completion().do(endpoint=self.endpoint, **input) + return Completion( + model=(self.model.name if self.model is not None else None), + endpoint=self.endpoint, + ) elif self.service_type == ServiceType.Embedding: - return Embedding().do(endpoint=self.endpoint, **input) + return Embedding( + model=(self.model.name if self.model is not None else None), + endpoint=self.endpoint, + ) elif self.service_type == ServiceType.Text2Image: - return Text2Image().do(endpoint=self.endpoint, **input) + return Text2Image( + model=(self.model.name if self.model is not None else None), + endpoint=self.endpoint, + ) else: raise InvalidArgumentError(f"unsupported service type {self.service_type}") + def deploy(self) -> "Service": + if self.model is None: + raise InvalidArgumentError("model not found") + model = self.model + if model.id is None or model.version_id is None: + raise InvalidArgumentError("model id | model version id not found") + if self.deploy_config is None: + raise InvalidArgumentError("deploy config not found") + svc_publish_resp = api.Service.create( + model_id=model.id, + model_version_id=model.version_id, + iteration_id=model.version_id, + name=f"svc{model.id}{model.version_id}", + uri=( + self.deploy_config.endpoint_prefix + if self.deploy_config != "" + else f"ep{model.id}{model.version_id}" + ), + replicas=self.deploy_config.replicas, + pool_type=self.deploy_config.pool_type, + ) + + self.id = svc_publish_resp["result"]["serviceId"] + if self.id is None: + raise InternalError("service id not found") + # 资源付费完成后,serviceStatus会变成Deploying,查看模型服务状态 + while True: + resp = api.Service.get(id=self.id) + svc_status = resp["result"]["serviceStatus"] + if svc_status != console_const.ServiceStatus.Deploying.value: + sft_model_endpoint = resp["result"]["uri"] + break + else: + log_info("please check console for service deployment") + time.sleep(get_config().DEPLOY_STATUS_POLLING_INTERVAL) + + self.endpoint = sft_model_endpoint + return self + def dumps(self) -> Optional[bytes]: """ serialize the model instance to bytes @@ -379,33 +459,5 @@ def model_deploy(model: Model, deploy_config: DeployConfig) -> Service: deploy_config=deploy_config, service_type=deploy_config.service_type, ) - if model.id is None or model.version_id is None: - raise InvalidArgumentError("model id | model version id not found") - svc_publish_resp = api.Service.create( - model_id=model.id, - model_version_id=model.version_id, - iteration_id=model.version_id, - name=f"svc{model.id}{model.version_id}", - uri=( - deploy_config.endpoint_prefix - if deploy_config != "" - else f"ep{model.id}{model.version_id}" - ), - replicas=deploy_config.replicas, - pool_type=deploy_config.pool_type, - ) - - svc.id = svc_publish_resp["result"]["serviceId"] - if svc.id is None: - raise InternalError("service id not found") - # 资源付费完成后,serviceStatus会变成Deploying,查看模型服务状态 - while True: - resp = api.Service.get(id=svc.id) - svc_status = resp["result"]["serviceStatus"] - if svc_status != console_const.ServiceStatus.Deploying.value: - sft_model_endpoint = resp["result"]["uri"] - break - time.sleep(get_config().DEPLOY_STATUS_POLLING_INTERVAL) - - svc.endpoint = sft_model_endpoint + svc.deploy() return svc