diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py index 7606911531ca2..875b54b4862bd 100644 --- a/python/paddle/distributed/auto_parallel/api.py +++ b/python/paddle/distributed/auto_parallel/api.py @@ -336,6 +336,7 @@ def _init_func(var, block): placements=placements, **tensor.__dict__, ) + dist_param.stop_gradient = tensor.stop_gradient if tensor._init_func is not None: origin_init_func = tensor._init_func dist_param.set_init_func( diff --git a/python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py b/python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py index 77d42537d1085..02156533d46f4 100644 --- a/python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py +++ b/python/paddle/distributed/auto_parallel/intermediate/tensor_parallel.py @@ -96,7 +96,7 @@ class PlanBase: def __init__(self): self.share_param_list = {} - def apply(self, layer, process_mesh, shard_weight, shard_bias): + def apply(self, layer, process_mesh, shard_param_list): raise NotImplementedError("Don't call the PlanBase directly.") @@ -151,7 +151,7 @@ def gather_hook(layer, input, output): return gather_hook - def apply(self, layer, process_mesh, shard_weight=True, shard_bias=True): + def apply(self, layer, process_mesh, shard_param_list): index = process_mesh.dim_names.index('mp') # get the axis for the split size = len(process_mesh.shape) placement = [dist.Replicate() for _ in range(size)] @@ -163,41 +163,44 @@ def apply(self, layer, process_mesh, shard_weight=True, shard_bias=True): f"But got {layer.__class__.__name__}. " f"Will try to shard weight and bias if the layer contains one." ) - if ( - hasattr(layer, "weight") - and layer.weight is not None - and shard_weight - ): - placement[index] = dist.Shard(1) - assert len(layer.weight.shape) == 2 - # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end - if ( - self.share_param_list is not None - and layer.weight.name in self.share_param_list - and self.share_param_list[layer.weight.name] > 1 - ): - param_placements.update({"weight": placement}) - else: - layer.weight = dist.shard_tensor( - layer.weight, - process_mesh, - placement, - ) - if hasattr(layer, "bias") and layer.bias is not None and shard_bias: - placement[index] = dist.Shard(0) - assert len(layer.bias.shape) == 1 - # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end + shard_param_list = set(shard_param_list) + if len(shard_param_list) == 0: + shard_param_list.add("weight") + shard_param_list.add("bias") + + def shard_param(param_name): if ( - self.share_param_list is not None - and layer.bias.name in self.share_param_list - and self.share_param_list[layer.bias.name] > 1 + hasattr(layer, param_name) + and getattr(layer, param_name) is not None ): - param_placements.update({"bias": placement}) - else: - layer.bias = dist.shard_tensor( - layer.bias, process_mesh, placement - ) - + layer_param = getattr(layer, param_name) + + if layer_param.is_dist(): + return + + if len(layer_param.shape) == 2: + placement[index] = dist.Shard(1) + elif len(layer_param.shape) == 1: + placement[index] = dist.Shard(0) + else: + raise ValueError(f"{layer_param} should have 1 or 2 dims.") + # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end + if ( + self.share_param_list is not None + and layer_param.name in self.share_param_list + and self.share_param_list[layer_param.name] > 1 + ): + 
param_placements.update({param_name: placement}) + else: + layer_param = dist.shard_tensor( + layer_param, + process_mesh, + placement, + ) + setattr(layer, param_name, layer_param) + + for param_name in shard_param_list: + shard_param(param_name) if self.gather_output: layer.register_forward_post_hook( self.gather_output_hook(process_mesh) @@ -252,7 +255,7 @@ def split_hook(layer, input): return split_hook - def apply(self, layer, process_mesh, shard_weight=True, shard_bias=False): + def apply(self, layer, process_mesh, shard_param_list): index = process_mesh.dim_names.index('mp') # get the axis for the split size = len(process_mesh.shape) placement = [dist.Replicate() for _ in range(size)] @@ -265,25 +268,38 @@ def apply(self, layer, process_mesh, shard_weight=True, shard_bias=False): f"But got {layer.__class__.__name__}. " f"Will try to shard weight if the layer contains one." ) - if ( - hasattr(layer, "weight") - and layer.weight is not None - and shard_weight - ): - assert len(layer.weight.shape) == 2 - # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end + shard_param_list = set(shard_param_list) + shard_param_list.discard("bias") + if len(shard_param_list) == 0: + shard_param_list.add("weight") + + def shard_param(param_name): if ( - self.share_param_list is not None - and layer.weight.name in self.share_param_list - and self.share_param_list[layer.weight.name] > 1 + hasattr(layer, param_name) + and getattr(layer, param_name) is not None ): - param_placements.update({"weight": placement}) - else: - layer.weight = dist.shard_tensor( - layer.weight, - process_mesh, - placement, - ) + layer_param = getattr(layer, param_name) + if layer_param.is_dist(): + return + if len(layer_param.shape) != 2: + raise ValueError(f"{layer_param} should have 2 dims.") + # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end + if ( + self.share_param_list is not None + and layer_param.name in self.share_param_list + and self.share_param_list[layer_param.name] > 1 + ): + param_placements.update({param_name: placement}) + else: + layer_param = dist.shard_tensor( + layer_param, + process_mesh, + placement, + ) + setattr(layer, param_name, layer_param) + + for param_name in shard_param_list: + shard_param(param_name) if not self.is_input_parallel: layer.register_forward_pre_hook(self.split_input_hook(process_mesh)) return param_placements @@ -340,7 +356,7 @@ def __init__( assert callable(fn) self.fn = fn - def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None): + def apply(self, layer, process_mesh, shard_param_list): layer.register_forward_pre_hook(self.fn(process_mesh=process_mesh)) @@ -395,7 +411,7 @@ def __init__( assert callable(fn) self.fn = fn - def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None): + def apply(self, layer, process_mesh, shard_param_list): layer.register_forward_post_hook(self.fn(process_mesh=process_mesh)) @@ -445,7 +461,7 @@ def begin(layer, input, output): return begin - def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None): + def apply(self, layer, process_mesh, shard_param_list): layer.register_forward_post_hook( self.sequence_parallel_begin(process_mesh) ) @@ -497,7 +513,7 @@ def end(layer, input, output=None): return end - def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None): + def apply(self, layer, process_mesh, shard_param_list): layer.register_forward_pre_hook( self.sequence_parallel_end(process_mesh) ) @@ -547,7 
+563,7 @@ def end(layer, input, output):
 
         return end
 
-    def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None):
+    def apply(self, layer, process_mesh, shard_param_list):
         logging.warning(
             "Sequence parallel with the usage of SequenceParallel may not reach the best throughput. "
             "Try to use SequenceParallelBegin/End to achieve better performance"
@@ -609,7 +625,7 @@ def end(layer, input, output=None):
 
         return end
 
-    def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None):
+    def apply(self, layer, process_mesh, shard_param_list):
         layer.register_forward_pre_hook(
             self.sequence_parallel_end(process_mesh)
         )
@@ -639,22 +655,23 @@ def __init__(self, model, parallelize_plan=None):
         self.parallelize_plan = parallelize_plan
         self.tp_parallelizer = self.tensor_parallelizer_fn
 
-    def match_layer(self, name):
+    def match_layer(self, layer, name):
         # Match the layer to a plan.
         # Will return the plan if the layer hits one, otherwise return None.
         plans = []
         for key, plan in self.parallelize_plan.items():
-            shard_weight = True
-            shard_bias = True
+            attr_name = key.split('.')[-1]
+            shard_param_list = []
             # Find some plan for specific parameter, such as
             # "lm_head.weight": ColWiseParallel()
-            # Only support weight or bias.
-            if key.endswith(".weight"):
-                key = key.replace(".weight", "")
-                shard_bias = False
-            elif key.endswith(".bias"):
-                key = key.replace(".bias", "")
-                shard_weight = False
+            # "qkv_proj.lora_A": ColWiseParallel()
+            # If there is no plan for a specific parameter, the layer is sharded by default: layer.weight and layer.bias.
+            if key.endswith(f".{attr_name}"):
+                if hasattr(layer, attr_name) and is_tensor(
+                    getattr(layer, attr_name)
+                ):
+                    key = key.replace(f".{attr_name}", "")
+                    shard_param_list.append(attr_name)
             re_find = re.match(key, name)
             if key == name or (
                 re_find is not None
@@ -662,7 +679,7 @@ def match_layer(self, name):
             ):
                 if isinstance(plan, PlanBase):
                     plan = [plan]
-                plans.append([plan, shard_weight, shard_bias])
+                plans.append([plan, shard_param_list])
         return plans
 
     def tensor_parallelizer_fn(self, model):
@@ -678,19 +695,16 @@ def tensor_parallelizer_fn(self, model):
                     continue
                 share_param_list[param.name] += 1
         for name, layer in model.named_sublayers():
-            plans = self.match_layer(name)
+            plans = self.match_layer(layer, name)
             layer_param_placements[layer] = {}
            if len(plans) > 0:
                 pp_idx = getattr(layer, "pipeline_stage_index", 0)
                 for plan in plans:
-                    real_plan, shard_weight, shard_bias = plan
+                    real_plan, shard_param_list = plan
                     for p in real_plan:
                         p.share_param_list = share_param_list
                         param_placements = p.apply(
-                            layer,
-                            self.get_mesh(pp_idx),
-                            shard_weight,
-                            shard_bias,
+                            layer, self.get_mesh(pp_idx), shard_param_list
                         )
                         if param_placements is not None and param_placements:
                             layer_param_placements[layer].update(
diff --git a/test/auto_parallel/hybrid_strategy/CMakeLists.txt b/test/auto_parallel/hybrid_strategy/CMakeLists.txt
index 7720eb5443683..52e0129834890 100644
--- a/test/auto_parallel/hybrid_strategy/CMakeLists.txt
+++ b/test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -161,3 +161,11 @@ if((WITH_GPU) AND (LINUX))
   set_tests_properties(test_to_distributed_api_for_llama
                        PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
 endif()
+if((WITH_GPU) AND (LINUX))
+  py_test_modules(
+    test_parallel_api_with_llama_lora MODULES test_parallel_api_with_llama_lora
+    ENVS
+    "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
+  set_tests_properties(test_parallel_api_with_llama_lora
+                       PROPERTIES TIMEOUT "360" LABELS "RUN_TYPE=HYBRID")
+endif()
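The changes above replace the boolean shard_weight/shard_bias switches with a shard_param_list, so a parallelize plan can target any tensor attribute of a layer by name, not just weight and bias: ColWiseParallel now shards every listed 2-D parameter along dim 1 and every 1-D parameter along dim 0, while RowWiseParallel drops "bias" from the list and only accepts 2-D parameters. A minimal sketch of the resulting user-facing plan follows; the keys mirror the LoRA entries in parallel_api.py below, but the dist.parallelize call and config layout shown in the trailing comment are assumptions about the surrounding intermediate API, not something this hunk spells out.

import paddle.distributed as dist

plan = {
    # Whole-layer entries still shard layer.weight / layer.bias by default.
    "model.llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(
        gather_output=True
    ),
    # Parameter-level entries name a single attribute; match_layer() turns the
    # trailing ".lora_B" into shard_param_list=["lora_B"] for that layer.
    "model.llama.layers.*.self_attn.q_proj.lora_B": dist.ColWiseParallel(),
    "model.llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(
        is_input_parallel=False
    ),
    "model.llama.layers.*.self_attn.o_proj.lora_A": dist.RowWiseParallel(),
}

# parallel_api.py feeds such a plan through its mp_config; the call below is an
# assumption about the intermediate API entry point, not part of this hunk:
# model, optimizer = dist.parallelize(
#     model, optimizer, config={"mp_config": {"parallelize_plan": plan}}
# )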
diff --git a/test/auto_parallel/hybrid_strategy/parallel_api.py b/test/auto_parallel/hybrid_strategy/parallel_api.py index db23e29ed8008..aace5ea47d222 100644 --- a/test/auto_parallel/hybrid_strategy/parallel_api.py +++ b/test/auto_parallel/hybrid_strategy/parallel_api.py @@ -14,10 +14,12 @@ import logging import os import random +from dataclasses import dataclass from functools import reduce import numpy as np from single_llama_model import LlamaForCausalLM, LlamaPretrainingCriterion +from single_lora_model import LoRAModel import paddle import paddle.distributed as dist @@ -57,6 +59,32 @@ class Config: use_lazy_init = False +@dataclass +class LoRaConfig: + r = 8 + lora_alpha = 8 + lora_dropout = 0.0 + rslora = False + lora_plus_scale = 1.0 + pissa = False + use_quick_lora = False + lora_use_mixer = False + use_mora = False + trainable_bias = False + trainable_modules = None + target_modules = [ + ".*q_proj.*", + ".*v_proj.*", + ".*k_proj.*", + ".*o_proj.*", + ".*qkv_proj.*", + ".*gate_proj.*", + ".*down_proj.*", + ".*up_proj.*", + ".*gate_up_fused_proj.*", + ] + + class RandomDataset(Dataset): def __init__(self, seq_len, num_samples=100): super().__init__() @@ -107,6 +135,7 @@ def apply_decay_param_fun(x): class TestParallelAPI: def __init__(self): self.config = Config() + self.lora_config = LoRaConfig() self.dp = int(os.getenv("dp")) self.mp = int(os.getenv("mp")) self.pp = int(os.getenv("pp")) @@ -143,6 +172,9 @@ def __init__(self): self.one_api = True seed = int(os.getenv("seed", 2024)) + self.share_embedding = int(os.getenv("test_share_embedding", "0")) + self.position_embedding = int(os.getenv("test_position_embedding", "0")) + self.test_lora = int(os.getenv("test_lora", "0")) np.random.seed(seed) random.seed(seed) paddle.seed(seed) @@ -160,7 +192,7 @@ def init_dist_env(self): global_mesh = dist.ProcessMesh(mesh_arr, dim_names) dist.auto_parallel.set_mesh(global_mesh) - def check_mp(self, layer, share_embedding): + def check_mp(self, layer): if self.mp == 1: return for name, sub_layer in layer.named_sublayers(): @@ -174,14 +206,24 @@ def check_mp(self, layer, share_embedding): dist.Replicate(), dist.Shard(0), ] + if self.test_lora: + assert sub_layer.lora_B.placements == [ + dist.Replicate(), + dist.Shard(1), + ] if 'gate_proj' in name or 'up_proj' in name: assert sub_layer.weight.placements == [ dist.Replicate(), dist.Shard(1), ] + if self.test_lora: + assert sub_layer.lora_B.placements == [ + dist.Replicate(), + dist.Shard(1), + ] if ( 'embed_tokens' in name or 'lm_head' in name - ) and not share_embedding: + ) and not self.share_embedding: assert sub_layer.weight.placements == [ dist.Replicate(), dist.Shard(1), @@ -190,94 +232,141 @@ def check_mp(self, layer, share_embedding): assert sub_layer.weight.placements == [ dist.Replicate(), dist.Shard(0), - ] + ], f'{name} , {sub_layer.weight.name} , {sub_layer.weight}' + if self.test_lora: + assert sub_layer.lora_A.placements == [ + dist.Replicate(), + dist.Shard(0), + ] # assert sub_layer.bias.placements is None if 'down_proj' in name: assert sub_layer.weight.placements == [ dist.Replicate(), dist.Shard(0), ] + if self.test_lora: + assert sub_layer.lora_A.placements == [ + dist.Replicate(), + dist.Shard(0), + ] + + def check_lora(self, layer): + if not self.test_lora: + return + for name, sub_layer in layer.named_sublayers(): + if len(sub_layer.sublayers()) == 0: + if 'q_proj' in name or 'k_proj' in name or 'v_proj' in name: + assert sub_layer.weight.stop_gradient + assert not sub_layer.lora_A.stop_gradient + assert not 
sub_layer.lora_B.stop_gradient + if 'gate_proj' in name or 'up_proj' in name: + assert sub_layer.weight.stop_gradient + assert not sub_layer.lora_A.stop_gradient + assert not sub_layer.lora_B.stop_gradient + if ( + 'embed_tokens' in name or 'lm_head' in name + ) and not self.share_embedding: + assert sub_layer.weight.stop_gradient + if 'o_proj' in name: + assert ( + sub_layer.weight.stop_gradient + ), f'{name} , {sub_layer.weight.name} , {sub_layer.weight}' + assert not sub_layer.lora_A.stop_gradient + assert not sub_layer.lora_B.stop_gradient + # assert sub_layer.bias.stop_gradient is None + if 'down_proj' in name: + assert sub_layer.weight.stop_gradient + assert not sub_layer.lora_A.stop_gradient + assert not sub_layer.lora_B.stop_gradient - def parallel_model(self, layer, share_embedding=False): + def parallel_model(self, layer): dp_config = None mp_config = None pp_config = None + prefix = "model." if self.test_lora else "" if self.pp > 1: # decoders_per_rank = self.config.num_hidden_layers // self.pp # split_spec = { - # f"llama.layers.{i * decoders_per_rank - 1}": SplitPoint.END + # ff"{prefix}llama.layers.{i * decoders_per_rank - 1}": SplitPoint.END # for i in range(1, self.pp) # } pp_config = { - 'split_spec': "llama.layers", - "global_spec": "llama.global_layer", + 'split_spec': f"{prefix}llama.layers", + "global_spec": f"{prefix}llama.global_layer", } if self.dp > 1: dp_config = {'sharding_level': self.level} if self.mp > 1: if not self.sequence_parallel: plan = { - "llama.embed_tokens": dist.ColWiseParallel( + f"{prefix}llama.embed_tokens": dist.ColWiseParallel( gather_output=True ), - "llama.position_embedding": dist.ColWiseParallel(), - "llama.layers.*.self_attn.q_proj": dist.ColWiseParallel( + f"{prefix}llama.position_embedding": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel( gather_output=True ), - "llama.layers.*.self_attn.k_proj": dist.ColWiseParallel( + f"{prefix}llama.layers.*.self_attn.q_proj.lora_B": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel( gather_output=True ), - "llama.layers.*.self_attn.v_proj": dist.ColWiseParallel( + f"{prefix}llama.layers.*.self_attn.k_proj.lora_B": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel( gather_output=True ), - "llama.layers.*.self_attn.o_proj": dist.RowWiseParallel( + f"{prefix}llama.layers.*.self_attn.v_proj.lora_B": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel( is_input_parallel=False ), - "llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), - "llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), - "llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), - "lm_head.weight": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.o_proj.lora_A": dist.RowWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_proj.lora_B": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.up_proj.lora_B": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), + f"{prefix}llama.layers.*.mlp.down_proj.lora_A": dist.RowWiseParallel(), + f"{prefix}lm_head.weight": dist.ColWiseParallel(), } else: if self.prepare_input_output: plan = { - "llama.embed_tokens": dist.ColWiseParallel(), - "llama.position_embedding": dist.ColWiseParallel(), - "llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(), - 
"llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(), - "llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(), - "llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(), - "llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), - "llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), - "llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), - "lm_head.weight": dist.ColWiseParallel(), - "llama.layers.*.input_layernorm": dist.SequenceParallelEnable(), - "llama.layers.*.post_attention_layernorm": dist.SequenceParallelEnable(), - "llama.norm": dist.SequenceParallelEnable(), + f"{prefix}llama.embed_tokens": dist.ColWiseParallel(), + f"{prefix}llama.position_embedding": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(), + f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), + f"{prefix}lm_head.weight": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.input_layernorm": dist.SequenceParallelEnable(), + f"{prefix}llama.layers.*.post_attention_layernorm": dist.SequenceParallelEnable(), + f"{prefix}llama.norm": dist.SequenceParallelEnable(), } else: plan = { - "llama.embed_tokens": [ + f"{prefix}llama.embed_tokens": [ dist.ColWiseParallel(), dist.SequenceParallelBegin(), ], - "llama.position_embedding": [ + f"{prefix}llama.position_embedding": [ dist.ColWiseParallel(), dist.SequenceParallelBegin(), ], - "llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(), - "llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(), - "llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(), - "llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(), - "llama.layers.*.self_attn": dist.SequenceParallelDisable(), - "llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), - "llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), - "llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), - "llama.layers.*.mlp": dist.SequenceParallelDisable( + f"{prefix}llama.layers.*.self_attn.q_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.k_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.v_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(), + f"{prefix}llama.layers.*.self_attn": dist.SequenceParallelDisable(), + f"{prefix}llama.layers.*.mlp.gate_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.up_proj": dist.ColWiseParallel(), + f"{prefix}llama.layers.*.mlp.down_proj": dist.RowWiseParallel(), + f"{prefix}llama.layers.*.mlp": dist.SequenceParallelDisable( need_transpose=False ), - "lm_head.weight": dist.ColWiseParallel(), - "lm_head": dist.SequenceParallelEnd(), + f"{prefix}lm_head.weight": dist.ColWiseParallel(), + f"{prefix}lm_head": dist.SequenceParallelEnd(), } mp_config = {'parallelize_plan': plan} @@ -308,25 +397,27 @@ def parallel_model(self, layer, share_embedding=False): optimizer, config=config, ) - self.check_mp(layer, share_embedding) + self.check_mp(layer) + self.check_lora(layer) return layer, optimizer, lr_scheduler - def run_llama( - self, share_embedding=False, position_embedding=False, to_static=0 - ): + def run_llama(self, to_static=0): if self.config.use_lazy_init: with LazyGuard(): model = LlamaForCausalLM( - 
self.config, share_embedding, position_embedding + self.config, self.share_embedding, self.position_embedding ) else: model = LlamaForCausalLM( - self.config, share_embedding, position_embedding + self.config, self.share_embedding, self.position_embedding ) - - model, optimizer, lr_scheduler = self.parallel_model( - model, share_embedding - ) + if self.test_lora: + if self.config.use_lazy_init: + with LazyGuard(): + model = LoRAModel(model, self.lora_config) + else: + model = LoRAModel(model, self.lora_config) + model, optimizer, lr_scheduler = self.parallel_model(model) criterion = LlamaPretrainingCriterion(self.config) @@ -456,12 +547,10 @@ def run_llama( if step >= 3: break - def run_test_cases(self, share_embedding=False, position_embedding=False): - self.run_llama(share_embedding, position_embedding, 0) - self.run_llama(share_embedding, position_embedding, 1) + def run_test_cases(self): + self.run_llama(0) + self.run_llama(1) if __name__ == '__main__': - share_embedding = int(os.getenv("test_share_embedding", "0")) - position_embedding = int(os.getenv("test_position_embedding", "0")) - TestParallelAPI().run_test_cases(share_embedding, position_embedding) + TestParallelAPI().run_test_cases() diff --git a/test/auto_parallel/hybrid_strategy/single_lora_model.py b/test/auto_parallel/hybrid_strategy/single_lora_model.py new file mode 100644 index 0000000000000..b9580421e50b9 --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/single_lora_model.py @@ -0,0 +1,449 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+import logging +import math +import re + +import numpy as np + +import paddle +import paddle.nn.functional as F +from paddle import nn + + +class LoRALinear(nn.Linear): + # LoRA implemented in a dense layer + def __init__( + self, + in_features: int, + out_features: int, + r: int = 0, + lora_alpha: int = 1, + lora_dropout: float = 0.0, + use_quick_lora: bool = False, + rslora: bool = False, + lora_plus_scale: float = 1.0, + pissa: bool = False, + lora_use_mixer: bool = False, + use_mora: bool = False, + **kwargs, + ): + nn.Linear.__init__(self, in_features, out_features, **kwargs) + if not isinstance(r, int) or r <= 0: + raise ValueError("Lora rank r should be a positive integer") + self.use_mora = use_mora + self.r = r + self.lora_alpha = lora_alpha + # Optional dropout + if lora_dropout > 0.0: + self.lora_dropout = nn.Dropout(p=lora_dropout) + else: + self.lora_dropout = lambda x: x + # Mark the weight as unmerged + self.merged = False + self.pissa = pissa + self.lora_use_mixer = lora_use_mixer + + # Actual trainable parameters + if use_mora: # reset the rank and create high rank matrix + self.in_features = in_features + self.out_features = out_features + new_r = int(math.sqrt((in_features + out_features) * r) + 0.5) + new_r = new_r // 2 * 2 + self.r = new_r + self.lora_A = self.create_parameter( + shape=[self.r, self.r], + dtype=self._dtype, + is_bias=False, + default_initializer=nn.initializer.Constant(value=0.0), + ) + self.cos = None + self.sin = None + # Count the number of tiles + self.rb1 = ( + self.in_features // self.r + if self.in_features % self.r == 0 + else self.in_features // self.r + 1 + ) + self.rb2 = ( + self.out_features // self.r + if self.out_features % self.r == 0 + else self.out_features // self.r + 1 + ) + self.rope_init() + else: + self.lora_A = self.create_parameter( + shape=[in_features, r], + dtype=self._dtype, + is_bias=False, + ) + if self.lora_use_mixer: + self.lora_AB = self.create_parameter( + shape=[r, r], + dtype=self._dtype, + is_bias=False, + ) + self.lora_B = self.create_parameter( + shape=[r, out_features], + dtype=self._dtype, + is_bias=False, + attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Constant(value=0.0), + learning_rate=lora_plus_scale, + ), + ) + self.apply_pissa = False + if use_mora or pissa: + self.scaling = 1.0 + elif not rslora: + self.scaling = self.lora_alpha / self.r + else: + self.scaling = self.lora_alpha / math.sqrt(self.r) + + # Freezing the pre-trained weight matrix + self.weight.stop_gradient = True + self._use_quick_lora = use_quick_lora and lora_dropout == 0.0 + self.disable_lora = False + + def pissa_init(self, rank): + weight = self.weight + dtype = weight.dtype + if dtype != paddle.float32: + weight = weight.astype(paddle.float32) + + U, S, Vh = paddle.linalg.svd(weight.data, full_matrices=False) + Ur = U[:, :rank] + Sr = S[:rank] + Vhr = Vh[:rank] + + lora_A = Ur @ paddle.diag(paddle.sqrt(Sr)) + lora_B = paddle.diag(paddle.sqrt(Sr)) @ Vhr + self.lora_A.set_value(lora_A.astype(dtype)) + self.lora_B.set_value(lora_B.astype(dtype)) + res = weight.data - lora_A @ lora_B + weight = res.astype(dtype) + self.weight.set_value(weight) + + def rope_init(self): + if self.cos is None or self.sin is None: + inv_freq = 1.0 / ( + 10000 + ** (paddle.arange(0, self.r, 2, dtype=paddle.float32) / self.r) + ) + t = paddle.arange(self.rb1, dtype=paddle.float32) + freqs = t.unsqueeze(1) @ inv_freq.unsqueeze(0) + emb = paddle.concat([freqs, freqs], axis=-1) + self.cos = paddle.unsqueeze(paddle.cos(emb), axis=0).astype( + 
self._dtype + ) + self.sin = paddle.unsqueeze(paddle.sin(emb), axis=0).astype( + self._dtype + ) + + @property + def use_quick_lora(self): + return self._use_quick_lora and self.training and not self.merged + + def _apply_mora(self, x): + r = self.r + + # Calculate grouping + sum_inter = self.in_features // r + + # padding + if self.in_features % r != 0: + pad_size = r - self.in_features % r + x = paddle.concat([x, x[..., :pad_size]], axis=-1) + sum_inter += 1 + + # reshape the input to apply RoPE + in_x = x.reshape([*x.shape[:-1], sum_inter, r]) + + # apply RoPE rotation + rh_in_x = paddle.concat( + [-in_x[..., r // 2 :], in_x[..., : r // 2]], axis=-1 + ) + in_x = in_x * self.cos + rh_in_x * self.sin + + # matmul with high rank matrix + out_x = in_x @ self.lora_A + + # reshape the output + out_x = out_x.reshape([*x.shape[:-1], -1])[..., : self.out_features] + if out_x.shape[-1] < self.out_features: + repeat_time = self.out_features // out_x.shape[-1] + if self.out_features % out_x.shape[-1] != 0: + repeat_time += 1 + out_x = paddle.concat([out_x] * repeat_time, axis=-1)[ + ..., : self.out_features + ] + + return out_x + + def get_delta_weight(self, lora_A=None, lora_B=None, lora_AB=None): + # compute the delta weight,which is used to merge weights + if self.lora_use_mixer: + lora_A = lora_A if lora_A is not None else self.lora_A + lora_B = lora_B if lora_B is not None else self.lora_B + lora_AB = lora_AB if lora_AB is not None else self.lora_AB + delta_weight = lora_A @ lora_AB @ lora_B * self.scaling + elif self.use_mora: + lora_A = lora_A if lora_A is not None else self.lora_A + r = self.r + # compute padding + pad_size = ( + r - self.in_features % r if self.in_features % r != 0 else 0 + ) + # initialize weights + w = paddle.zeros( + [self.in_features + pad_size, self.in_features], + dtype=lora_A.dtype, + ) + + # create the weights after rotation + aw2 = paddle.concat( + [lora_A[:, r // 2 :], -lora_A[:, : r // 2]], axis=-1 + ) + # apply RoPE + for i in range(self.rb1 - 1): + w[i * r : (i + 1) * r, i * r : (i + 1) * r] = ( + aw2 * self.sin[:, i] + lora_A * self.cos[:, i] + ) + # Process the last chunk that may be incomplete + i = self.rb1 - 1 + w[i * r :, i * r :] = ( + aw2 * self.sin[:, i] + lora_A * self.cos[:, i] + )[:, : r - pad_size] + # padding + if pad_size > 0: + w[i * r :, :pad_size] = ( + aw2 * self.sin[:, i] + lora_A * self.cos[:, i] + )[:, r - pad_size :] + # reshape the weights + if self.in_features < self.out_features: + w = paddle.concat([w] * self.rb2, axis=0)[: self.out_features] + else: + w = w[: self.out_features] + final_weight = w + delta_weight = final_weight.T + else: + lora_A = lora_A if lora_A is not None else self.lora_A + lora_B = lora_B if lora_B is not None else self.lora_B + delta_weight = lora_A @ lora_B * self.scaling + + return delta_weight + + def merge(self): + if not self.merged: + delta_weight = self.get_delta_weight() + new_weight = self.weight + delta_weight + self.weight.set_value(new_weight) + self.merged = True + + def unmerge(self): + if self.merged: + delta_weight = self.get_delta_weight() + new_weight = self.weight - delta_weight + self.weight.set_value(new_weight) + self.merged = False + + def forward(self, input: paddle.Tensor, *args, **kwargs): + if not self.apply_pissa and self.pissa: + self.pissa_init(self.r) + self.apply_pissa = True + if self.disable_lora or self.merged: + result = F.linear( + x=input, weight=self.weight, bias=self.bias, name=self.name + ) + elif self.use_mora: + result = F.linear( + x=input, weight=self.weight, 
bias=self.bias, name=self.name + ) + input = self.lora_dropout(input) + mora_out = self._apply_mora(input) + result += mora_out + else: + result = F.linear( + x=input, weight=self.weight, bias=self.bias, name=self.name + ) + if self.lora_use_mixer: + result += ( + self.lora_dropout(input) + @ self.lora_A + @ self.lora_AB + @ self.lora_B + ) * self.scaling + else: + result += ( + self.lora_dropout(input) @ self.lora_A @ self.lora_B + ) * self.scaling + return result + + def extra_repr(self): + name = f", name={self.name}" if self.name else "" + return f"in_features={self.weight.shape[0]}, out_features={self.weight.shape[1]}, rank={self.r}{name}" + + +lora_layers = { + "LoRALinear": LoRALinear, +} +LoRALinear = lora_layers["LoRALinear"] +AVAILABLE_LAYERS = [ + LoRALinear, +] + + +class LoRAModel(nn.Layer): + + def __init__(self, model, lora_config) -> None: + super().__init__() + self.model = self.get_lora_model(model, lora_config) + + self.lora_config = lora_config + logging.info("Mark only lora and trainable_module as trainable.") + self.mark_only_lora_as_trainable() + + def forward(self, input_ids): + return self.model(input_ids) + + def _find_and_replace_module(self, model, module_name, lora_config): + parent_module = model + attribute_chain = module_name.split(".") + for name in attribute_chain[:-1]: + parent_module = getattr(parent_module, name) + module = getattr(parent_module, attribute_chain[-1]) + lora_module = None + if isinstance(module, nn.Linear): + lora_module = LoRALinear( + in_features=module.weight.shape[0], + out_features=module.weight.shape[1], + r=lora_config.r, + lora_alpha=lora_config.lora_alpha, + lora_dropout=lora_config.lora_dropout, + rslora=lora_config.rslora, + lora_plus_scale=lora_config.lora_plus_scale, + pissa=lora_config.pissa, + bias_attr=False if module.bias is None else None, + use_quick_lora=lora_config.use_quick_lora, + lora_use_mixer=lora_config.lora_use_mixer, + use_mora=lora_config.use_mora, + ) + if lora_module is None: + raise ValueError( + f"LoRA strategy only supports paddle.nn.Linear or paddle.distributed.fleet.meta_parallel.ColumnParallelLinear or paddlenlp.transformers.sequence_utils. 
{module}({module_name} {type(module).__name__}) is not supported。" + ) + lora_module.weight = module.weight + if module.bias is not None: + lora_module.bias = module.bias + setattr(parent_module, attribute_chain[-1], lora_module) + + def print_trainable_parameters(self) -> None: + freeze_numel = 0 + trainable_numel = 0 + for _, weight in self.model.state_dict().items(): + if weight.stop_gradient: + freeze_numel += np.prod(weight.shape) + else: + trainable_numel += np.prod(weight.shape) + logging.debug( + f"Frozen parameters: {freeze_numel:.2e} || Trainable parameters:{trainable_numel:.2e} || Total parameters:{freeze_numel + trainable_numel:.2e}|| Trainable:{trainable_numel / (freeze_numel + trainable_numel):.2%}" + ) + + def mark_only_lora_as_trainable(self) -> None: + for _, layer in self.model.named_sublayers(): + if isinstance(layer, LoRALinear): + for name, weight in layer.state_dict().items(): + if ( + self.lora_config.trainable_bias in ["lora", "all"] + and "bias" in name + ): + weight.stop_gradient = False + elif "lora" in name: + weight.stop_gradient = False + else: + weight.stop_gradient = True + else: + for name, weight in layer.state_dict().items(): + if ( + self.lora_config.trainable_bias == "all" + and "bias" in name + ): + weight.stop_gradient = False + else: + weight.stop_gradient = True + if self.lora_config.trainable_modules is not None: + for name, weight in self.model.state_dict().items(): + if any( + re.fullmatch(trainable_module, name) + for trainable_module in self.lora_config.trainable_modules + ): + weight.stop_gradient = False + + def get_lora_model(self, model, lora_config): + if lora_config.target_modules is None: + return model + elif isinstance(lora_config.target_modules, str): + target_modules = [lora_config.target_modules] + else: + target_modules = lora_config.target_modules + for target_module in target_modules: + for i in model.named_sublayers(): + module_name = i[0] + if re.fullmatch(target_module, module_name): + self._find_and_replace_module( + model, module_name, lora_config + ) + return model + + def train(self): + self.training = True + self.model.training = True + for layer in self.model.sublayers(): + layer.training = True + layer.train() + + def eval(self): + self.training = False + self.model.training = False + for layer in self.model.sublayers(): + layer.training = False + layer.eval() + + def disable_lora(self): + for _, layer in self.model.named_sublayers(): + if any( + isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS + ): + layer.disable_lora = True + + def enable_lora(self): + for _, layer in self.model.named_sublayers(): + if any( + isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS + ): + layer.disable_lora = False + + def merge(self): + for _, layer in self.model.named_sublayers(): + if any( + isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS + ): + layer.merge() + + def unmerge(self): + for _, layer in self.model.named_sublayers(): + if any( + isinstance(layer, lora_layer) for lora_layer in AVAILABLE_LAYERS + ): + layer.unmerge() diff --git a/test/auto_parallel/hybrid_strategy/test_parallel_api_with_llama_lora.py b/test/auto_parallel/hybrid_strategy/test_parallel_api_with_llama_lora.py new file mode 100644 index 0000000000000..34283c363d67a --- /dev/null +++ b/test/auto_parallel/hybrid_strategy/test_parallel_api_with_llama_lora.py @@ -0,0 +1,67 @@ +# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import tempfile
+import unittest
+
+import collective.test_communication_api_base as test_base
+
+
+class TestDPMPPPAPI(test_base.CommunicationTestDistBase):
+    def setUp(self):
+        super().setUp(num_of_devices=8, timeout=360, nnode=1)
+        self._default_envs = {
+            "dtype": "float32",
+            "seed": "2023",
+            "dp": "2",
+            "mp": "2",
+            "pp": "2",
+            "acc_step": "2",
+        }
+        self._changeable_envs = {
+            "backend": ["gpu"],
+            "amp": ["true"],
+            "amp_level": ["O2"],
+            "amp_dtype": ["bfloat16"],
+            "amp_master_grad": ["true"],
+            "use_lazy_init": ["false"],
+            "sequence_parallel": ["false"],
+            "prepare_input_output": ["false"],
+            "sharding_stage": ["0", "1"],
+            "test_share_embedding": [
+                "0",
+            ],
+            "test_position_embedding": [
+                "1",
+            ],
+            "one_api": ["true", "false"],
+            "test_lora": ["1"],
+        }
+
+    def test_simple_lora_net_dp2_mp2_pp2(self):
+        envs_list = test_base.gen_product_envs_list(
+            self._default_envs, self._changeable_envs
+        )
+        for envs in envs_list:
+            ckpt_path = tempfile.TemporaryDirectory()
+            envs["ckpt_path"] = ckpt_path.name
+            self.run_test_case(
+                "parallel_api.py",
+                user_defined_envs=envs,
+            )
+            ckpt_path.cleanup()
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/test/auto_parallel/hybrid_strategy/testslist.csv b/test/auto_parallel/hybrid_strategy/testslist.csv
index 07a777711fad7..f38b29bafe264 100644
--- a/test/auto_parallel/hybrid_strategy/testslist.csv
+++ b/test/auto_parallel/hybrid_strategy/testslist.csv
@@ -18,3 +18,4 @@ test_parallel_api_with_llama_1d,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_parallel_api_with_llama_2d,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_parallel_api_with_llama_3d,LINUX,GPU,400,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
 test_to_distributed_api_for_llama,LINUX,GPU,180,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
+test_parallel_api_with_llama_lora,LINUX,GPU,360,HYBRID,test_runner.py,,,http_proxy=;https_proxy=;PYTHONPATH=../..,
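For context on the new test, the LoRALinear added in single_lora_model.py freezes the base weight (weight.stop_gradient = True, which the api.py change now preserves when parameters are re-created as dist tensors) and trains only lora_A with shape [in_features, r] and lora_B with shape [r, out_features], scaled by lora_alpha / r (or lora_alpha / sqrt(r) with rslora). Those shapes are why check_mp expects lora_B to carry the same Shard(1) placement as a column-parallel weight and lora_A the same Shard(0) placement as a row-parallel weight. A minimal sketch of the plain-LoRA forward and merge math, with MoRA, PiSSA, the mixer, and dropout left out:

import paddle
import paddle.nn.functional as F

def lora_forward(x, weight, bias, lora_A, lora_B, scaling):
    # Base projection with the frozen weight (weight.stop_gradient is True).
    out = F.linear(x=x, weight=weight, bias=bias)
    # Trainable low-rank correction: [in, r] @ [r, out], scaled by lora_alpha / r.
    # LoRALinear additionally applies lora_dropout to x before this matmul.
    return out + (x @ lora_A @ lora_B) * scaling

# Shapes mirror the LoRaConfig defaults in parallel_api.py (r=8, lora_alpha=8).
x = paddle.randn([2, 16])
W = paddle.randn([16, 32])
W.stop_gradient = True            # frozen base weight
A = paddle.randn([16, 8]) * 0.01  # lora_A: [in_features, r]
B = paddle.zeros([8, 32])         # lora_B starts at zero, so the update is a no-op at init
y = lora_forward(x, W, None, A, B, scaling=8 / 8)

# merge() folds the same update into the base weight so inference needs no extra matmuls:
#   W_merged = W + (lora_A @ lora_B) * scaling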