From da41c4f252da63fc468628a1cba07bf6a310e750 Mon Sep 17 00:00:00 2001
From: lugimzzz <63761690+lugimzzz@users.noreply.github.com>
Date: Mon, 16 Dec 2024 16:41:11 +0800
Subject: [PATCH] [llm]add adam-mini (#9542)

* add adam-mini

* fix following comments
---
 docs/trainer.md                    |   3 +
 llm/docs/dpo.md                    |   1 +
 llm/docs/finetune.md               |   1 +
 paddlenlp/trainer/trainer.py       |   5 +
 paddlenlp/trainer/trainer_utils.py |   1 +
 paddlenlp/trainer/training_args.py |   2 +
 paddlenlp/utils/__init__.py        |   1 +
 paddlenlp/utils/optimizer.py       | 151 +++++++++++++++++++++++++++++
 tests/fixtures/llm/adamw_mini.yaml |  35 +++++++
 tests/llm/test_adamw_mini.py       |  53 ++++++++++
 10 files changed, 253 insertions(+)
 create mode 100644 paddlenlp/utils/optimizer.py
 create mode 100644 tests/fixtures/llm/adamw_mini.yaml
 create mode 100644 tests/llm/test_adamw_mini.py

diff --git a/docs/trainer.md b/docs/trainer.md
index e5c33f21e848..d643c99268e4 100644
--- a/docs/trainer.md
+++ b/docs/trainer.md
@@ -691,6 +691,9 @@ Trainer 是一个简单,但功能完整的 Paddle 训练和评估模块,并
   --optim
                        优化器名称,默认为adamw,(`str`, 可选,默认为 `adamw`)
                        The optimizer to use. (default: adamw)
+                       可能的值为:
+                       - `"adamw"`
+                       - `"adamw_mini"`
 
   --report_to
                        日志可视化显示,默认使用visualdl可视化展示。(可选,默认为 None,展示所有)
diff --git a/llm/docs/dpo.md b/llm/docs/dpo.md
index 639059ddd0d0..4ca084e5834b 100644
--- a/llm/docs/dpo.md
+++ b/llm/docs/dpo.md
@@ -119,6 +119,7 @@ python -u -m paddle.distributed.launch --gpus "0,1,2,3,4,5,6,7" ./alignment/dpo
 - `unified_checkpoint`: 是否使用统一的 checkpoint,默认为 `True`。
 - `autotuner_benchmark`: 是否启用 autotuner 基准测试,默认为 `False`。
 - `benchmark`: 是否开启基准测试,默认为 `False`。
+- `optim`:默认为`adamw`,支持`adamw`, `adamw_mini`。
 ### DPO 参数(DPOArguments)
 - `beta`: DPO 损失函数的 beta 参数,默认为 0.1。
 - `simpo_gamma`: SimPO 损失函数的 gamma 参数,默认为 0.5。
diff --git a/llm/docs/finetune.md b/llm/docs/finetune.md
index 9d3d8ffcfb38..3d6f2184a0ff 100644
--- a/llm/docs/finetune.md
+++ b/llm/docs/finetune.md
@@ -184,6 +184,7 @@ python merge_lora_params.py \
 - `pipeline_parallel_degree`: 表示划分流水线的大小.(假设该参数为4, 模型12层, 则每一个 pp stage 包含3层模型) 默认值-1, 表示不启用流水线并行。
 - `sharding_parallel_degree`: 表示分组参数切片的数据并行大小. 默认值1, 表示不启用分组参数切片的数据并行。
 - `sharding`:是否使用 Paddle 的 Sharding 数据并行功能,用户的参数。支持 sharding `stage1`, `stage2` or `stage3`。其中`stage2``stage3`可以和`offload`组合使用。
+- `optim`:默认为`adamw`,支持`adamw`, `adamw_mini`。
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 57c655736f25..59d74011e717 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -1915,6 +1915,11 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:
 
             optimizer_cls = AdamW
             optimizer_kwargs.update(adam_kwargs)
+        elif args.optim == OptimizerNames.ADAMW_MINI:
+            from ..utils import AdamWMini
+
+            optimizer_cls = AdamWMini
+            optimizer_kwargs.update(adam_kwargs)
         else:
             raise ValueError(f"Trainer cannot instantiate unsupported optimizer: {args.optim}")
         return optimizer_cls, optimizer_kwargs
diff --git a/paddlenlp/trainer/trainer_utils.py b/paddlenlp/trainer/trainer_utils.py
index 33ded2ce5bf6..e04f330c6050 100644
--- a/paddlenlp/trainer/trainer_utils.py
+++ b/paddlenlp/trainer/trainer_utils.py
@@ -317,6 +317,7 @@ class OptimizerNames(ExplicitEnum):
 
     ADAMW = "adamw"
     ADAFACTOR = "adafactor"
+    ADAMW_MINI = "adamw_mini"
 
 
 class ShardingOption(ExplicitEnum):
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 5d1dad82a831..6f9f501cdc8c 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1018,6 +1018,8 @@ def __post_init__(self):
             raise ValueError("At most one of fp16 and bf16 can be True for full eval, but not both")
 
         self.optim = OptimizerNames(self.optim)
+        if self.optim == OptimizerNames.ADAMW_MINI and self.tensor_parallel_degree > 1:
+            raise ValueError("AdamW Mini currently doesn't support tensor parallelism.")
 
         self.use_hybrid_parallel = False
diff --git a/paddlenlp/utils/__init__.py b/paddlenlp/utils/__init__.py
index a8c4dc487a0e..3b5950b0d701 100644
--- a/paddlenlp/utils/__init__.py
+++ b/paddlenlp/utils/__init__.py
@@ -21,6 +21,7 @@
 from .import_utils import *
 from .infohub import infohub
 from .initializer import to
+from .optimizer import *
 from .serialization import load_torch
 
 # hack impl for EagerParamBase to function
diff --git a/paddlenlp/utils/optimizer.py b/paddlenlp/utils/optimizer.py
new file mode 100644
index 000000000000..0b2904eb9e53
--- /dev/null
+++ b/paddlenlp/utils/optimizer.py
@@ -0,0 +1,151 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+from paddle import pir
+from paddle.base import core, framework
+from paddle.base.framework import Variable, in_dynamic_or_pir_mode, in_pir_mode
+from paddle.base.libpaddle import DataType
+from paddle.optimizer.adamw import AdamW
+from paddle.pir import Value
+
+
+class AdamWMini(AdamW):
+    def _add_moments_pows(self, p):
+        acc_dtype = p.dtype
+        if self._is_dtype_fp16_or_bf16(acc_dtype):
+            acc_dtype = DataType.FLOAT32 if in_pir_mode() else paddle.float32
+
+        self._add_accumulator(self._moment1_acc_str, p, dtype=acc_dtype)
+        # Adam-mini: keep a single scalar second moment per parameter instead of per-element
+        self._add_accumulator(self._moment2_acc_str, p, dtype=acc_dtype, shape=[1])
+        try:
+            type = core.VarDesc.VarType.DENSE_TENSOR
+        except AttributeError:
+            type = core.VarDesc.VarType.LOD_TENSOR
+        self._add_accumulator(
+            name=self._beta1_pow_acc_str,
+            param=p,
+            dtype=acc_dtype,
+            fill_value=0.9 if isinstance(self._beta1, (Variable, Value)) else self._beta1,
+            shape=[1],
+            type=type,
+            device="cpu",
+        )
+        self._add_accumulator(
+            name=self._beta2_pow_acc_str,
+            param=p,
+            dtype=acc_dtype,
+            fill_value=0.999 if isinstance(self._beta2, (Variable, Value)) else self._beta2,
+            shape=[1],
+            type=type,
+            device="cpu",
+        )
+
+    def _append_optimize_op(self, block, param_and_grad):
+        assert isinstance(block, (framework.Block, pir.Block))
+        if isinstance(param_and_grad, dict):
+            param_and_grad = self._update_param_group(param_and_grad)
+        param = param_and_grad[0]
+
+        # Whether we should do weight decay for the parameter.
+        with_decay = True
+        if self._apply_decay_param_fun is not None and not self._apply_decay_param_fun(param.name):
+            with_decay = False
+
+        moment1 = self._get_accumulator_master(self._moment1_acc_str, param_and_grad[0])
+        moment2 = self._get_accumulator_master(self._moment2_acc_str, param_and_grad[0])
+        beta1_pow_acc = self._get_accumulator_master(self._beta1_pow_acc_str, param_and_grad[0])
+        beta2_pow_acc = self._get_accumulator_master(self._beta2_pow_acc_str, param_and_grad[0])
+        find_master = self._multi_precision and self._is_dtype_fp16_or_bf16(param_and_grad[0].dtype)
+        master_weight = self._master_weights[param_and_grad[0].name] if find_master else None
+        lr = self._create_param_lr(param_and_grad)
+        # create the adamw optimize op
+        if in_dynamic_or_pir_mode():
+            lr_ratio_ = 1.0 if self._lr_ratio is None else self._lr_ratio(param_and_grad[0])
+
+            _beta1 = self._beta1 if not isinstance(self._beta1, Variable) else self._beta1.item(0)
+            _beta2 = self._beta2 if not isinstance(self._beta2, Variable) else self._beta2.item(0)
+
+            found_inf = self._get_auxiliary_var("found_inf") if in_pir_mode() else None
+            self.adamw_python(
+                param_and_grad[0],
+                param_and_grad[1],
+                lr,
+                moment1,
+                moment2,
+                beta1_pow_acc,
+                beta2_pow_acc,
+                master_weight,
+                found_inf,
+                _beta1,
+                _beta2,
+                self._epsilon,
+                lr_ratio_,
+                self._weight_decay,
+                with_decay,
+                find_master,
+            )
+            return None
+        else:
+            raise NotImplementedError("Not implemented yet.")
+
+    def adamw_python(
+        self,
+        param,
+        grad,
+        learning_rate,
+        moment1,
+        moment2,
+        beta1_pow,
+        beta2_pow,
+        master_weight,
+        skip_update,
+        beta1,
+        beta2,
+        epsilon,
+        lr_ratio,
+        coeff,
+        with_decay,
+        multi_precision,
+    ):
+        if skip_update:
+            return
+        if not with_decay:
+            coeff = 0.0
+        if not multi_precision:
+            master_weight = None
+        lr = learning_rate * lr_ratio
+        if master_weight is not None:
+            p = master_weight
+        else:
+            p = param
+        p *= 1.0 - lr * coeff
+        mom1 = moment1
+        mom2 = moment2
+
+        mom1 = beta1 * mom1 + (1.0 - beta1) * grad
+        mom2 = beta2 * mom2 + (1.0 - beta2) * (grad * grad).mean()
+        denom = mom2.sqrt() / (1.0 - beta2_pow).sqrt() + epsilon
+        p += (mom1 / denom) * (-(lr / (1.0 - beta1_pow)))
+        if master_weight is not None:
+            master_weight[:] = p
+            param[:] = p.astype(param.dtype)
+        else:
+            param[:] = p
+        moment1[:] = mom1
+        moment2[:] = mom2
+        beta1_pow[:], beta2_pow[:] = beta1 * beta1_pow[:], beta2 * beta2_pow[:]
+        # TODO: double-check how these accumulators should be updated
+        return
diff --git a/tests/fixtures/llm/adamw_mini.yaml b/tests/fixtures/llm/adamw_mini.yaml
new file mode 100644
index 000000000000..6dc6e9b865b9
--- /dev/null
+++ b/tests/fixtures/llm/adamw_mini.yaml
@@ -0,0 +1,35 @@
+finetune:
+  base:
+    dataset_name_or_path: "./data"
+    per_device_train_batch_size: 4
+    gradient_accumulation_steps: 4
+    per_device_eval_batch_size: 8
+    eval_accumulation_steps: 16
+    num_train_epochs: 3
+    learning_rate: 3e-05
+    warmup_steps: 30
+    logging_steps: 1
+    evaluation_strategy: "epoch"
+    save_strategy: "epoch"
+    src_length: 1024
+    max_length: 2048
+    fp16: true
+    fp16_opt_level: "O2"
+    do_train: true
+    do_eval: true
+    use_flash_attention: true
+    disable_tqdm: true
+    load_best_model_at_end: true
+    eval_with_do_generation: false
+    metric_for_best_model: "accuracy"
+    recompute: true
+    refined_recompute: "flash_attn:-1"
+    save_total_limit: 1
+    tensor_parallel_degree: 1
+    pipeline_parallel_degree: 1
+    ignore_save_lr_and_optim: 1
+    optim: "adamw_mini"
+
+  default:
+    llama:
+      model_name_or_path: __internal_testing__/tiny-random-llama
diff --git a/tests/llm/test_adamw_mini.py b/tests/llm/test_adamw_mini.py
new file mode 100644
index 000000000000..383d82407a06
--- /dev/null
+++ b/tests/llm/test_adamw_mini.py
@@ -0,0 +1,53 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from __future__ import annotations
+
+import sys
+import unittest
+
+from parameterized import parameterized_class
+
+from tests.testing_utils import argv_context_guard, load_test_config
+
+from .testing_utils import LLMTest
+
+
+@parameterized_class(
+    ["model_dir"],
+    [
+        ["llama"],
+    ],
+)
+class FinetuneTest(LLMTest, unittest.TestCase):
+    config_path: str = "./tests/fixtures/llm/adamw_mini.yaml"
+    model_dir: str = None
+
+    def setUp(self) -> None:
+        LLMTest.setUp(self)
+
+        sys.path.insert(0, self.model_dir)
+
+    def tearDown(self) -> None:
+        LLMTest.tearDown(self)
+
+    def test_finetune(self):
+        finetune_config = load_test_config(self.config_path, "finetune", self.model_dir)
+
+        finetune_config["dataset_name_or_path"] = self.data_dir
+        finetune_config["output_dir"] = self.output_dir
+
+        with argv_context_guard(finetune_config):
+            from run_finetune import main
+
+            main()
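
With this patch applied, the new optimizer is selected through the existing `optim` training argument: `--optim adamw_mini` on the command line, or `optim: "adamw_mini"` in a fixture like the one above. The sketch below (not part of the patch) shows the same selection done directly through `TrainingArguments`; the output directory and hyperparameter values are illustrative only, and the tensor-parallel note mirrors the check this patch adds to `training_args.py`.

    from paddlenlp.trainer import TrainingArguments

    # Hypothetical values; any output_dir / learning rate would do.
    args = TrainingArguments(
        output_dir="./checkpoints/adamw_mini",
        optim="adamw_mini",        # resolved to OptimizerNames.ADAMW_MINI in __post_init__
        learning_rate=3e-5,
        tensor_parallel_degree=1,  # values > 1 raise ValueError for adamw_mini
    )
    print(args.optim)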
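The `adamw_python` method above applies the usual AdamW step except that the second moment is a single scalar per parameter, computed from the mean squared gradient (the Adam-mini idea), which is why its accumulator is created with `shape=[1]`. The standalone NumPy sketch below restates that update rule outside Paddle for illustration only; the function name and hyperparameter defaults are made up and not part of the patch.

    import numpy as np

    def adamw_mini_step(p, g, m, v, t, lr=3e-5, beta1=0.9, beta2=0.999, eps=1e-8, wd=0.01):
        p = p * (1.0 - lr * wd)                         # decoupled weight decay
        m = beta1 * m + (1.0 - beta1) * g               # per-element first moment
        v = beta2 * v + (1.0 - beta2) * np.mean(g * g)  # single scalar second moment
        denom = np.sqrt(v) / np.sqrt(1.0 - beta2**t) + eps
        p = p - lr * (m / (1.0 - beta1**t)) / denom     # bias-corrected update
        return p, m, v

    p, m, v = np.ones(4), np.zeros(4), 0.0
    p, m, v = adamw_mini_step(p, np.full(4, 0.1), m, v, t=1)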