Skip to content

Commit

Permalink
Merge branch 'main' into fix_client_chinese_input
Browse files Browse the repository at this point in the history
  • Loading branch information
ZingLix committed Jan 10, 2024
2 parents 7119ba5 + 6ad72e2 commit 33aa721
Show file tree
Hide file tree
Showing 20 changed files with 291 additions and 68 deletions.
19 changes: 10 additions & 9 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,10 @@ $ qianfan [OPTIONS] COMMAND [ARGS]...
* `--secret-key TEXT`:百度智能云安全认证 Secret Key,获取方式参考 [文档](https://cloud.baidu.com/doc/Reference/s/9jwvz2egb)
* `--ak TEXT` [过时]:千帆平台应用的 API Key,仅能用于模型推理部分 API,获取方式参考 [文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Slkkydake)
* `--sk TEXT` [过时]:千帆平台应用的 Secret Key,仅能用于模型推理部分 API,获取方式参考 [文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Slkkydake)
* `--enable-traceback`:打印完整的错误堆栈信息,仅在发生异常时有效。
* `--version -v`:打印版本信息。
* `--install-completion`:为当前 shell 安装自动补全脚本。
* `--show-completion`:展示自动补全脚本。
* `--install-shell-autocomplete`:为当前 shell 安装自动补全脚本。
* `--show-shell-autocomplete`:展示自动补全脚本。
* `--help -h`:打印帮助文档。

**命令**:
Expand Down Expand Up @@ -123,7 +124,7 @@ $ qianfan dataset [OPTIONS] COMMAND [ARGS]...

> ⚠️ 在下方各个数据集的命令中,涉及数据集 id 均指平台上的数据集版本 id,与 Dataset 模块定义一致,具体获取方式参考 [文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Uloic6krs)
>
> 使用时可以直接传数据集的 id,也可以使用链接形式避免与文件名产生歧义,格式为 `qianfan://{model_version_id}`,例如 `qianfan://18562`
> 使用时可以直接传数据集的 id,也可以使用链接形式避免与文件名产生歧义,格式为 `qianfan://{dataset_version_id}`,例如 `qianfan://18562`
>
> 如果由于本地文件名为数字,导致和数据集 id 混淆,可以在文件名前增加 `./` 避免歧义,例如 `./18562`
Expand All @@ -139,7 +140,7 @@ $ qianfan dataset predict [OPTIONS] DATASET

**Arguments 参数**:

* `DATASET`:待预测的数据集。值可以是一个本地文件的路径,也可以是平台上的数据集链接 (格式为 `qianfan://{model_version_id}`)。 [required]
* `DATASET`:待预测的数据集。值可以是一个本地文件的路径,也可以是平台上的数据集链接 (格式为 `qianfan://{dataset_version_id}`)。 [required]

**Options 选项**:

Expand Down Expand Up @@ -184,7 +185,7 @@ $ qianfan dataset upload [OPTIONS] PATH [DST]
**Arguments 参数**:

* `SRC`:数据集文件路径。 [required]
* `[DST]`:目标数据集 id,该参数可选。如果不提供该值,那么将会在平台上创建一个新的数据集,否则数据将被追加至所提供的数据集中。值可以是数据集的 id 或者是千帆数据集链接 (qianfan://{model_version_id})。
* `[DST]`:目标数据集 id,该参数可选。如果不提供该值,那么将会在平台上创建一个新的数据集,否则数据将被追加至所提供的数据集中。值可以是数据集的 id 或者是千帆数据集链接 (qianfan://{dataset_version_id})。

**Options 选项**:

Expand All @@ -210,8 +211,8 @@ $ qianfan dataset save [OPTIONS] SRC [DST]

**Arguments 参数**:

* `SRC`:源数据集。值可以是一个本地文件的路径,也可以是平台上的数据集链接 (格式为 `qianfan://{model_version_id}`)。 [required]
* `[DST]`:目标数据集。如果值是一个本地文件路径,那么数据将保存至该文件中。或者可以提供一个已有的千帆数据集链接 (qianfan://{model_version_id}),那么数据将被追加至该数据集中。如果不提供该值,那么将会在平台上创建一个新的数据集,此时需要提供创建数据集所需的一些参数,具体见下文。
* `SRC`:源数据集。值可以是一个本地文件的路径,也可以是平台上的数据集链接 (格式为 `qianfan://{dataset_version_id}`)。 [required]
* `[DST]`:目标数据集。如果值是一个本地文件路径,那么数据将保存至该文件中。或者可以提供一个已有的千帆数据集链接 (qianfan://{dataset_version_id}),那么数据将被追加至该数据集中。如果不提供该值,那么将会在平台上创建一个新的数据集,此时需要提供创建数据集所需的一些参数,具体见下文。

**Options 选项**:

Expand All @@ -233,11 +234,11 @@ $ qianfan dataset view [OPTIONS] DATASET

**Arguments 参数**:

* `DATASET`:待预览的数据集。值可以是一个本地文件的路径,也可以是平台上的数据集链接 (格式为 `qianfan://{model_version_id}`)。[required]
* `DATASET`:待预览的数据集。值可以是一个本地文件的路径,也可以是平台上的数据集链接 (格式为 `qianfan://{dataset_version_id}`)。[required]

**Options 参数**:

* `--row TEXT`:待预览的数据集行。用 `,` 分隔数个行,用 `-` 表示一个范围 (e.g. 1,3-5,12)
* `--row TEXT`:待预览的数据集行。用 `,` 分隔数个行,用 `-` 表示一个范围 (e.g. 1,3-5,12)。默认情况下仅打印前 5 行,可以通过 `--row all` 来打印所有数据。
* `--column TEXT`:待预览的数据集的列。用 `,` 分隔每个列名称。 (e.g. prompt,response)
* `--raw`:展示原始数据。
* `--help`:展示帮助文档。
105 changes: 105 additions & 0 deletions docs/tool.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
# Tool

为了方便将大模型与外部工具集成,千帆SDK提供了一套Tool框架。Tool可以类比成一个函数,它能够被Agent理解并使用,作为LLM与外部世界交互的工具。

基本流程是:Agent控制LLM,根据名称和描述判断用户的输入是否需要使用某个工具,如果确定需要使用一个工具,则进一步生成工具需要的参数,然后Agent使用参数调用对应的执行函数,并将执行函数的结果返回给LLM,再由LLM总结输出,最终返回给用户。

此外,SDK的Tool框架提供了一套转换方法,可以对LangChain等常见框架的Tool进行双向转换,以便集成其他生态。

## Tool类

Tool有两个核心类:

BaseTool:工具的基础类,用于定义工具的基本信息和运行方法。每个工具都必须基于此类实现,并定义名称、描述、参数列表以及实现run方法。

- **name**: 一个非常简短、清晰的名称,就像函数名那样。例如:baidu_search。
- **description**: 工具的描述,解释其功能和用途。例如:使用百度搜索引擎,在互联网上检索任何实时最新的相关信息。
- **parameters**: 调用这个Tool需要的参数,注意,参数应该和执行函数的入参一致。例如:search_query -> 搜索的关键词或短语。
- **run**: 接收参数并执行Tool对应的动作,然后返回执行结果。

ToolParameter:用于定义工具参数,包括参数的名称、类型、描述、属性以及是否为必需参数。

- **name**: 参数的名称。
- **description**: 参数的描述,可以包含其功能和格式要求。
- **type**: 参数的数据类型,如string、integer、object等,对应JSON schema中的类型。
- **properties**: 当参数类型为object时,定义该对象的属性列表。
- **required**: 表示参数是否必须提供。

## 定义工具

下面是一个简单的示例,用于定义并实现一个控制智能家居的工具。

light_switch工具用于控制智能电灯的开关,它接收一个switch的boolean参数用于表示开关状态,然后,我们在工具的run方法中实现控制逻辑。

```python
from typing import List
from qianfan.common.tool.base_tool import ToolParameter, BaseTool

class LightSwitchTool(BaseTool):
name: str = "light_switch"
description: str = "控制智能电灯的开关"
parameters: List[ToolParameter] = [
ToolParameter(
name="switch",
type="boolean",
description="开关状态",
required=True,
)
]

def run(self, parameters):
# 此处编写控制逻辑
return "灯已打开" if parameters["switch"] else "灯已关闭"
```

你可以在实例化一个工具类后直接运行它。

```python
light_switch_tool = LightSwitchTool()
print(light_switch_tool.run({"switch": True}))
```

在这个示例中,你应该会得到以下输出。

```
灯已打开
```

## 与外部能力集成

Tool类提供了以下方法:

- **to_function_call_schema**:将Tool转换为调用function call的JSON Schema。
- **to_langchain_tool**:将Tool转换为适配Langchain框架的Tool,可以被Langchain Agent直接调用。
- **from_langchain_tool**:这是一个静态方法,可以将Langchain框架的Tool实例转换为千帆SDK的Tool。

我们将开发完毕的LightSwitchTool实例化,随后调用to_langchain_tool方法来转换为Langchain的Tool,然后创建LLM和Agent,并传入Tool,最后运行。

```python
import os
from langchain.agents import AgentExecutor
from langchain_community.chat_models import QianfanChatEndpoint
from qianfan.extensions.langchain.agents import QianfanSingleActionAgent

os.environ["QIANFAN_AK"] = "此处填写你的AK"
os.environ["QIANFAN_SK"] = "此处填写你的SK"
tools = [LightSwitchTool().to_langchain_tool()]

llm = QianfanChatEndpoint(model="ERNIE-Bot")
agent = QianfanSingleActionAgent.from_system_prompt(tools, llm)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

print(agent_executor.run("帮我关闭电灯"))
```

在这个示例中,你应该会得到以下输出。

```
content='' additional_kwargs={'id': 'as-6harkfa2rn', 'object': 'chat.completion', 'created': 1704699483, 'result': '', 'is_truncated': False, 'need_clear_history': False, 'function_call': {'name': 'light_switch', 'arguments': '{"switch":false}'}, 'finish_reason': 'function_call', 'usage': {'prompt_tokens': 108, 'completion_tokens': 25, 'total_tokens': 133}}
content='根据你的请求,我已经帮你关闭了电灯。如果你需要打开电灯,请告诉我。' additional_kwargs={'id': 'as-ctg359ensw', 'object': 'chat.completion', 'created': 1704699485, 'result': '根据你的请求,我已经帮你关闭了电灯。如果你需要打开电灯,请告诉我。', 'is_truncated': False, 'need_clear_history': False, 'finish_reason': 'normal', 'usage': {'prompt_tokens': 134, 'completion_tokens': 19, 'total_tokens': 153}}
> Finished chain.
好的,我已经帮您关闭了电灯。如果您需要再次打开电灯,请告诉我。
```
6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ aiolimiter = ">=1.1.0"
importlib-metadata = { version = ">=1.4.0", python = "<=3.7" }
bce-python-sdk = ">=0.8.79"
typing-extensions = { version = ">=4.0.0", python = "<=3.10" }
pydantic = ">=2"
pydantic-settings = ">=2.0.3"
pydantic = "*"
python-dotenv = "<=0.21.1"
langchain = { version = ">=0.0.321", python = ">=3.8.1", optional = true }
numpy = [
{ version = "<1.22.0", python = ">=3.7 <3.8" },
Expand Down Expand Up @@ -73,7 +73,7 @@ preview = true
[tool.mypy]
ignore_missing_imports = "True"
disallow_untyped_defs = "True"
exclude = ["qianfan/tests"]
exclude = ["qianfan/tests", "qianfan/pydantic"]


[build-system]
Expand Down
10 changes: 5 additions & 5 deletions src/qianfan/common/client/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def completion_multi(self, messages: List[str]) -> None:


def completion_entry(
messages: List[str] = typer.Argument(..., help="Messages List"),
prompts: List[str] = typer.Argument(..., help="Prompt List"),
model: str = typer.Option(
DefaultLLMModel.Completion,
help="Model name of the completion model.",
Expand All @@ -102,12 +102,12 @@ def completion_entry(
"""
Complete the provided prompt or messages.
"""
if len(messages) % 2 != 1:
if len(prompts) % 2 != 1:
print_error_msg("The number of messages must be odd.")
raise typer.Exit(code=1)
client = CompletionClient(model, endpoint, plain)

if len(messages) == 1:
client.completion_single(messages[0])
if len(prompts) == 1:
client.completion_single(prompts[0])
else:
client.completion_multi(messages)
client.completion_multi(prompts)
21 changes: 14 additions & 7 deletions src/qianfan/common/client/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def save(
...,
help=(
"The source of the dataset. The value can be a file path or qianfan"
" dataset url (qianfan://{model_version_id})."
" dataset url (qianfan://{dataset_version_id})."
),
),
dst: Optional[str] = typer.Argument(
Expand All @@ -100,7 +100,7 @@ def save(
"The destination of the dataset. The dataset will be saved to a file if the"
" value is a path. Alternatively, the dataset can be appended to an"
" existing dataset on the platform if an qianfan dataset url is provided"
" (qianfan://{model_version_id}). If this value is not provided, a new"
" (qianfan://{dataset_version_id}). If this value is not provided, a new"
" dataset will be created on the platform."
),
),
Expand Down Expand Up @@ -202,7 +202,7 @@ def download(
...,
help=(
"The version id of the dataset on the qianfan platform. The value can be"
" qianfan dataset id or url(qianfan://{model_version_id})."
" qianfan dataset id or url(qianfan://{dataset_version_id})."
),
),
output: Path = typer.Option(Path(f"{timestamp()}.jsonl"), help="Output file path."),
Expand All @@ -227,7 +227,7 @@ def upload(
"The destination of the dataset. If this value is not provided, a new"
" dataset will be created on the platform. Alternatively, the dataset can"
" be appended to an existing dataset on the platform if an qianfan dataset"
" id or url(qianfan://{model_version_id}) is provided . "
" id or url(qianfan://{dataset_version_id}) is provided . "
),
),
dataset_name: Optional[str] = typer.Option(
Expand Down Expand Up @@ -274,14 +274,15 @@ def view(
...,
help=(
"The dataset to view. The value can be a file path or qianfan"
" dataset url (qianfan://{model_version_id})."
" dataset url (qianfan://{dataset_version_id})."
),
),
row: Optional[str] = typer.Option(
None,
help=(
"The row to view. Use commas(,) to view multiple rows and dashes(-) to"
" denote a range of data. (e.g. 1,3-5,12)"
" denote a range of data (e.g. 1,3-5,12). By default, only the top 5 rows"
" will be displayed. Alternatively, use '--row all' to view all rows."
),
),
column: Optional[str] = typer.Option(
Expand All @@ -302,6 +303,12 @@ def view(
# list of (start_idx, end_idx)
row_list = []
if row is None:
print_info_msg(
"No row index provided, only top 5 rows will be displayed. Or use '--row"
" all' to display all rows."
)
row_list.append((0, min(len(ds), 5)))
elif row == "all":
row_list.append((0, len(ds)))
else:
row_l = row.split(",")
Expand Down Expand Up @@ -383,7 +390,7 @@ def predict(
...,
help=(
"The dataset to predict. The value can be a file path or qianfan"
" dataset url (qianfan://{model_version_id})."
" dataset url (qianfan://{dataset_version_id})."
),
),
model: str = typer.Option(
Expand Down
Loading

0 comments on commit 33aa721

Please sign in to comment.