Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support trainer in client #192

Merged
merged 19 commits into from
Jan 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,6 @@ import qianfan

os.environ["QIANFAN_ACCESS_KEY"]="..."
os.environ["QIANFAN_SECRET_KEY"]="..."
# 通过 App Id 选择使用的应用
# 该参数可选,若不提供 SDK 会自动选择最新创建的应用
os.environ["QIANFAN_APPID"]="..."

chat_comp = qianfan.ChatCompletion(model="ERNIE-Bot")
resp = chat_comp.do(messages=[{
Expand Down
3 changes: 0 additions & 3 deletions README.pypi.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,6 @@ import qianfan

os.environ["QIANFAN_ACCESS_KEY"]="..."
os.environ["QIANFAN_SECRET_KEY"]="..."
# 通过 App Id 选择使用的应用
# 该参数可选,若不提供 SDK 会自动选择最新创建的应用
os.environ["QIANFAN_APPID"]="..."

chat_comp = qianfan.ChatCompletion(model="ERNIE-Bot")
resp = chat_comp.do(messages=[{
Expand Down
59 changes: 57 additions & 2 deletions docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ $ qianfan [OPTIONS] COMMAND [ARGS]...

### chat 对话

![](./imgs/cli/chat.gif)
![](./imgs/cli/chat.webp)

**用法**:

Expand Down Expand Up @@ -242,4 +242,59 @@ $ qianfan dataset view [OPTIONS] DATASET
* `--row TEXT`:待预览的数据集行。用 `,` 分隔数个行,用 `-` 表示一个范围 (e.g. 1,3-5,12)。默认情况下仅打印前 5 行,可以通过 `--row all` 来打印所有数据。
* `--column TEXT`:待预览的数据集的列。用 `,` 分隔每个列名称。 (e.g. prompt,response)
* `--raw`:展示原始数据。
* `--help`:展示帮助文档。
* `--help`:展示帮助文档。

## trainer 训练

**用法**:

```console
$ qianfan trainer [OPTIONS] COMMAND [ARGS]...
```

**Options 选项**:

* `--help`:展示帮助文档

**Commands 命令**:

* `run`:运行 trainer 任务

### run

运行 trainer 任务

**用法**:

```console
$ qianfan trainer run [OPTIONS]
```

**Options 选项**:

* `--train-type TEXT`:训练类型 [required]
* `--dataset-id INTEGER`:数据集 id [required]
* `--help`:展示帮助文档

训练相关配置,参数含义与 [训练 API 文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/mlmrgo4yx#body%E5%8F%82%E6%95%B0) 中对应参数含义一致:

* `--train-epoch INTEGER`:训练轮数
* `--train-batch-size INTEGER`:训练每轮 batch 的大小
* `--train-learning-rate FLOAT`:学习率
* `--train-max-seq-len INTEGER`:最大训练长度
* `--train-peft-type [all|p_tuning|lo_ra]`:Parameter efficient finetuning 方式
* `--trainset-rate INTEGER`:数据拆分比例 [default:20]
* `--train-logging-steps INTEGER`:日志记录间隔
* `--train-warmup-ratio FLOAT`:预热比例
* `--train-weight-decay FLOAT`:正则化系数
* `--train-lora-rank INTEGER`:LoRA 策略中的秩
* `--train-lora-all-linear TEXT`:LoRA 是否所有均为线性层

部署相关配置,参数含义与 [创建服务 API 文档](https://cloud.baidu.com/doc/WENXINWORKSHOP/s/Plnlmxdgy#body%E5%8F%82%E6%95%B0) 中对应参数含义一致:

* `--deploy-name TEXT`:部署服务名称。设置该值后会开始部署 action。
* `--deploy-endpoint-prefix TEXT`:部署服务的 endpoint 前缀
* `--deploy-description TEXT`:服务描述
* `--deploy-replicas INTEGER`:副本数 [default:1]
* `--deploy-pool-type [public_resource|private_resource]`:资源池类型 [default:private_resource]
* `--deploy-service-type [chat|completion|embedding|text2_image]`:服务类型 [default:chat]
Binary file removed docs/imgs/cli/chat.gif
Binary file not shown.
Binary file added docs/imgs/cli/chat.webp
Binary file not shown.
2 changes: 2 additions & 0 deletions src/qianfan/common/client/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from qianfan.common.client.completion import completion_entry
from qianfan.common.client.dataset import dataset_app
from qianfan.common.client.embedding import embedding_entry
from qianfan.common.client.trainer import trainer_app
from qianfan.common.client.txt2img import txt2img_entry
from qianfan.common.client.utils import print_error_msg, print_info_msg

Expand All @@ -38,6 +39,7 @@
app.command(name="txt2img")(txt2img_entry)
app.command(name="embedding", no_args_is_help=True)(embedding_entry)
app.add_typer(dataset_app, name="dataset")
app.add_typer(trainer_app, name="trainer")

_enable_traceback = False

Expand Down
Loading
Loading