Skip to content

Commit

Permalink
feat[trainer]: support resource config and increment training from mo…
Browse files Browse the repository at this point in the history
…del (#658)
  • Loading branch information
danielhjz authored Jul 12, 2024
1 parent 2c54b72 commit 62f6349
Show file tree
Hide file tree
Showing 4 changed files with 53 additions and 4 deletions.
4 changes: 3 additions & 1 deletion python/qianfan/common/client/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,9 @@ def print_trainer_config(config: ModelInfo) -> None:
from qianfan.trainer.configs import TrainConfig

limit_fields = (
TrainConfig().dict(exclude={"peft_type", "trainset_rate", "extras"}).keys()
TrainConfig()
.dict(exclude={"peft_type", "trainset_rate", "extras", "resource_config"})
.keys()
)
for k in limit_fields:
row_objs = []
Expand Down
9 changes: 8 additions & 1 deletion python/qianfan/trainer/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,9 @@ def _exec(self, input: Dict[str, Any] = {}, **kwargs: Any) -> Dict[str, Any]:

assert self.train_config is not None
hyper_params_dict = {
**self.train_config.dict(exclude={"peft_type", "trainset_rate", "extras"}),
**self.train_config.dict(
exclude={"peft_type", "trainset_rate", "extras", "resource_config"}
),
**self.train_config.extras,
}
hyper_params_dict = {
Expand All @@ -681,6 +683,11 @@ def _exec(self, input: Dict[str, Any] = {}, **kwargs: Any) -> Dict[str, Any]:
# 语料混合配置
if input.get("corpus_config"):
kwargs["corpus_config"] = input.get("corpus_config")
# 训练资源配置
if self.train_config.resource_config:
kwargs["resource_config"] = self.train_config.resource_config.dict(
by_alias=True, exclude_none=True
)
create_task_resp = api.FineTune.V2.create_task(
job_id=self.job_id,
params_scale=self.train_config.peft_type,
Expand Down
12 changes: 12 additions & 0 deletions python/qianfan/trainer/configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,17 @@ class CorpusConfig(BaseJsonModel):
"""


class ResourceConfig(BaseJsonModel):
    """Private resource pool configuration for a fine-tuning task.

    Serialized with camelCase aliases (`resourceId`, `nodeNum`) to match the
    console API payload; `None` fields are dropped by the caller via
    `dict(by_alias=True, exclude_none=True)`.
    """

    # NOTE: default was `[]`, which is a mutable *list* default on a `str`
    # field — pydantic does not validate defaults, so an unset resource_id
    # would silently hold a list at runtime. Use an empty string instead.
    resource_id: str = Field(default="", alias="resourceId")
    """
    id of the private resource pool to run the training job on
    """
    node_num: Optional[int] = Field(default=None, alias="nodeNum")
    """
    number of nodes to allocate from the resource pool
    """


class BaseTrainConfig(BaseModel):
peft_type: Optional[Union[str, PeftType]] = None
"""
Expand All @@ -111,6 +122,7 @@ class BaseTrainConfig(BaseModel):
"""
extra fields for train_config
"""
resource_config: Optional[ResourceConfig] = None

def validate_config(self, train_limit: "TrainLimit") -> bool:
schema = self.schema()
Expand Down
32 changes: 30 additions & 2 deletions python/qianfan/trainer/finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from qianfan.errors import InvalidArgumentError
from qianfan.evaluation.evaluator import Evaluator
from qianfan.model.configs import DeployConfig
from qianfan.model.model import Model
from qianfan.resources.console import consts as console_consts
from qianfan.trainer.actions import (
DeployAction,
Expand Down Expand Up @@ -64,6 +65,7 @@ def __init__(
dataset_bos_path: Optional[str] = None,
previous_trainer: Optional[Trainer] = None,
previous_task_id: Optional[str] = None,
previous_model: Optional[Any] = None,
name: Optional[str] = None,
**kwargs: Any,
) -> None:
Expand All @@ -73,7 +75,7 @@ def __init__(
Parameters:
train_type: str
A string representing the model version type.
like 'ERNIE-Bot-turbo-0725', 'ChatGLM2-6b'
like 'ERNIE-Speed-8K', 'ChatGLM2-6b'
dataset: Dataset
A dataset instance.
train_config: TrainConfig
Expand All @@ -95,8 +97,16 @@ def __init__(
this will be ignored when dataset is provided.
previous_trainer: Optional[Trainer]
An optional previous trainer instance for incremental training.
incremental training will try to load the model from
`previous_trainer`, `previous_model` or `previous_task_id`
previous_model: Optional[Any]
previous model object or model_id for incremental training.
incremental training will try to load the model from
`previous_trainer`, `previous_model` or `previous_task_id`
previous_task_id: Optional[str]
An optional previous task id for incremental training.
incremental training will try to load the model from
`previous_trainer`, `previous_model` or `previous_task_id`
name: Optional[str]
An optional name for the training task.
corpus_config: Optional[CorpusConfig] = None,
Expand All @@ -107,7 +117,7 @@ def __init__(
for calling example:
```
sft_task = LLMFinetune(
train_type="ERNIE-Bot-turbo-0725",
train_type="ERNIE-Speed-8K",
dataset={"datasets": [{"type": 1, "id": ds_id}]},
train_config=TrainConfig(...),
event_handler=eh,
Expand Down Expand Up @@ -166,6 +176,24 @@ def __init__(
is_incr=True,
**kwargs,
)
elif previous_model:
if isinstance(previous_model, str):
previous_model = Model(id=previous_model)
if isinstance(previous_model, Model):
from qianfan import resources
assert isinstance(previous_model.id, str)
resp = resources.Model.V2.describe_model(model_id=previous_model.id)
if resp["result"]["sourceInfo"].get("trainTaskId"):
previous_task_id = resp["result"]["sourceInfo"]["trainTaskId"]
self.train_action = TrainAction(
train_config=train_config,
task_id=previous_task_id,
train_mode=console_consts.TrainMode.SFT,
event_handler=event_handler,
job_name=name,
previous_model=previous_model,
**kwargs,
)
else:
# init train action from base model
self.train_action = TrainAction(
Expand Down

0 comments on commit 62f6349

Please sign in to comment.