[AutoParallel]: intermediate API supports LoRA #70539

Merged
9 commits merged on Jan 10, 2025
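This PR teaches the auto-parallel intermediate API to shard arbitrarily named layer parameters (for example LoRA's lora_A/lora_B adapters) instead of only weight and bias: every plan's apply() now receives a shard_param_list, and match_layer derives that list from the last segment of each parallelize_plan key. The snippet below is a usage sketch only; the model and layer names (llama, qkv_proj, lora_A, lora_B), the mesh shape, and the commented-out config keys are assumptions for illustration, not code taken from this PR or its tests.

# Hypothetical usage sketch (names and mesh are illustrative, not from this PR):
# with this change a parallelize_plan entry may target any tensor attribute of a
# layer, so LoRA adapter weights can be sharded the way the base weight is.
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])

plan = {
    # No parameter suffix: the plan shards the layer's weight (and bias) as before.
    "llama.embed_tokens": dist.ColWiseParallel(),
    # Parameter-specific entries: previously only ".weight"/".bias" were accepted;
    # now any attribute found on the matched layer works, e.g. LoRA adapters.
    "llama.layers.*.self_attn.qkv_proj.lora_A": dist.ColWiseParallel(),
    "llama.layers.*.self_attn.qkv_proj.lora_B": dist.ColWiseParallel(),
    "llama.layers.*.self_attn.o_proj": dist.RowWiseParallel(),
}

# A LoRA-wrapped network would then go through the intermediate API roughly as
# follows (config keys written from memory; check the paddle.distributed.parallelize
# docs before relying on them):
# model, optimizer = dist.parallelize(
#     model, optimizer, mesh=mesh,
#     config={"mp_config": {"parallelize_plan": plan}},
# )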
1 change: 1 addition & 0 deletions python/paddle/distributed/auto_parallel/api.py
@@ -336,6 +336,7 @@ def _init_func(var, block):
             placements=placements,
             **tensor.__dict__,
         )
+        dist_param.stop_gradient = tensor.stop_gradient
         if tensor._init_func is not None:
             origin_init_func = tensor._init_func
             dist_param.set_init_func(
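The one added line matters for LoRA: frozen base weights carry stop_gradient = True, and the distributed parameter rebuilt from such a tensor should stay frozen rather than silently becoming trainable. A minimal sketch of that invariant, using a hypothetical paddle.nn.Linear as a stand-in for a LoRA base layer (not code from this PR):

# Assumed illustration: a frozen LoRA base weight must remain frozen after the
# intermediate API rebuilds it as a dist parameter; the line above copies the
# flag so that keeps holding.
import paddle

base = paddle.nn.Linear(16, 16)
base.weight.stop_gradient = True  # LoRA freezes the base weight

trainable = [p.name for p in base.parameters() if not p.stop_gradient]
print(trainable)  # only the bias remains trainable here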
Second changed file: the intermediate-API tensor-parallel plans.

@@ -96,7 +96,7 @@ class PlanBase:
     def __init__(self):
         self.share_param_list = {}

-    def apply(self, layer, process_mesh, shard_weight, shard_bias):
+    def apply(self, layer, process_mesh, shard_param_list):
         raise NotImplementedError("Don't call the PlanBase directly.")

@@ -151,7 +151,7 @@ def gather_hook(layer, input, output):

         return gather_hook

-    def apply(self, layer, process_mesh, shard_weight=True, shard_bias=True):
+    def apply(self, layer, process_mesh, shard_param_list):
         index = process_mesh.dim_names.index('mp')  # get the axis for the split
         size = len(process_mesh.shape)
         placement = [dist.Replicate() for _ in range(size)]
@@ -163,41 +163,44 @@ def apply(self, layer, process_mesh, shard_weight=True, shard_bias=True):
                 f"But got {layer.__class__.__name__}. "
                 f"Will try to shard weight and bias if the layer contains one."
             )
-        if (
-            hasattr(layer, "weight")
-            and layer.weight is not None
-            and shard_weight
-        ):
-            placement[index] = dist.Shard(1)
-            assert len(layer.weight.shape) == 2
-            # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end
-            if (
-                self.share_param_list is not None
-                and layer.weight.name in self.share_param_list
-                and self.share_param_list[layer.weight.name] > 1
-            ):
-                param_placements.update({"weight": placement})
-            else:
-                layer.weight = dist.shard_tensor(
-                    layer.weight,
-                    process_mesh,
-                    placement,
-                )
-        if hasattr(layer, "bias") and layer.bias is not None and shard_bias:
-            placement[index] = dist.Shard(0)
-            assert len(layer.bias.shape) == 1
-            # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end
-            if (
-                self.share_param_list is not None
-                and layer.bias.name in self.share_param_list
-                and self.share_param_list[layer.bias.name] > 1
-            ):
-                param_placements.update({"bias": placement})
-            else:
-                layer.bias = dist.shard_tensor(
-                    layer.bias, process_mesh, placement
-                )
+        shard_param_list = set(shard_param_list)
+        if len(shard_param_list) == 0:
+            shard_param_list.add("weight")
+            shard_param_list.add("bias")
+
+        def shard_param(param_name):
+            if (
+                hasattr(layer, param_name)
+                and getattr(layer, param_name) is not None
+            ):
+                layer_param = getattr(layer, param_name)
+
+                if layer_param.is_dist():
+                    return
+
+                if len(layer_param.shape) == 2:
+                    placement[index] = dist.Shard(1)
+                elif len(layer_param.shape) == 1:
+                    placement[index] = dist.Shard(0)
+                else:
+                    raise ValueError(f"{layer_param} should have 1 or 2 dims.")
+                # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end
+                if (
+                    self.share_param_list is not None
+                    and layer_param.name in self.share_param_list
+                    and self.share_param_list[layer_param.name] > 1
+                ):
+                    param_placements.update({param_name: placement})
+                else:
+                    layer_param = dist.shard_tensor(
+                        layer_param,
+                        process_mesh,
+                        placement,
+                    )
+                    setattr(layer, param_name, layer_param)
+
+        for param_name in shard_param_list:
+            shard_param(param_name)
         if self.gather_output:
             layer.register_forward_post_hook(
                 self.gather_output_hook(process_mesh)
@@ -252,7 +255,7 @@ def split_hook(layer, input):

         return split_hook

-    def apply(self, layer, process_mesh, shard_weight=True, shard_bias=False):
+    def apply(self, layer, process_mesh, shard_param_list):
         index = process_mesh.dim_names.index('mp')  # get the axis for the split
         size = len(process_mesh.shape)
         placement = [dist.Replicate() for _ in range(size)]
@@ -265,25 +268,38 @@ def apply(self, layer, process_mesh, shard_weight=True, shard_bias=False):
                 f"But got {layer.__class__.__name__}. "
                 f"Will try to shard weight if the layer contains one."
             )
-        if (
-            hasattr(layer, "weight")
-            and layer.weight is not None
-            and shard_weight
-        ):
-            assert len(layer.weight.shape) == 2
-            # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end
-            if (
-                self.share_param_list is not None
-                and layer.weight.name in self.share_param_list
-                and self.share_param_list[layer.weight.name] > 1
-            ):
-                param_placements.update({"weight": placement})
-            else:
-                layer.weight = dist.shard_tensor(
-                    layer.weight,
-                    process_mesh,
-                    placement,
-                )
+        shard_param_list = set(shard_param_list)
+        shard_param_list.discard("bias")
+        if len(shard_param_list) == 0:
+            shard_param_list.add("weight")
+
+        def shard_param(param_name):
+            if (
+                hasattr(layer, param_name)
+                and getattr(layer, param_name) is not None
+            ):
+                layer_param = getattr(layer, param_name)
+                if layer_param.is_dist():
+                    return
+                if len(layer_param.shape) != 2:
+                    raise ValueError(f"{layer_param} should have 2 dims.")
+                # NOTE(zhangweilong):for share parameter, the parameter should be handled uniformly in the end
+                if (
+                    self.share_param_list is not None
+                    and layer_param.name in self.share_param_list
+                    and self.share_param_list[layer_param.name] > 1
+                ):
+                    param_placements.update({param_name: placement})
+                else:
+                    layer_param = dist.shard_tensor(
+                        layer_param,
+                        process_mesh,
+                        placement,
+                    )
+                    setattr(layer, param_name, layer_param)
+
+        for param_name in shard_param_list:
+            shard_param(param_name)
         if not self.is_input_parallel:
             layer.register_forward_pre_hook(self.split_input_hook(process_mesh))
         return param_placements
@@ -340,7 +356,7 @@ def __init__(
         assert callable(fn)
         self.fn = fn

-    def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None):
+    def apply(self, layer, process_mesh, shard_param_list):
         layer.register_forward_pre_hook(self.fn(process_mesh=process_mesh))


@@ -395,7 +411,7 @@ def __init__(
         assert callable(fn)
         self.fn = fn

-    def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None):
+    def apply(self, layer, process_mesh, shard_param_list):
         layer.register_forward_post_hook(self.fn(process_mesh=process_mesh))


@@ -445,7 +461,7 @@ def begin(layer, input, output):

         return begin

-    def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None):
+    def apply(self, layer, process_mesh, shard_param_list):
         layer.register_forward_post_hook(
             self.sequence_parallel_begin(process_mesh)
         )
@@ -497,7 +513,7 @@ def end(layer, input, output=None):

         return end

-    def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None):
+    def apply(self, layer, process_mesh, shard_param_list):
         layer.register_forward_pre_hook(
             self.sequence_parallel_end(process_mesh)
         )
@@ -547,7 +563,7 @@ def end(layer, input, output):

         return end

-    def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None):
+    def apply(self, layer, process_mesh, shard_param_list):
         logging.warning(
             "Sequence parallel with the usage of SequenceParallel may not reach the best throughput. "
             "Try to use SequenceParallelBegin/End to achieve better performance"
@@ -609,7 +625,7 @@ def end(layer, input, output=None):

         return end

-    def apply(self, layer, process_mesh, shard_weight=None, shard_bias=None):
+    def apply(self, layer, process_mesh, shard_param_list):
         layer.register_forward_pre_hook(
             self.sequence_parallel_end(process_mesh)
         )
@@ -639,30 +655,31 @@ def __init__(self, model, parallelize_plan=None):
         self.parallelize_plan = parallelize_plan
         self.tp_parallelizer = self.tensor_parallelizer_fn

-    def match_layer(self, name):
+    def match_layer(self, layer, name):
         # Match the layer to a plan.
         # Will return the plan if the layer hits one, otherwise return None.
         plans = []
         for key, plan in self.parallelize_plan.items():
-            shard_weight = True
-            shard_bias = True
+            attr_name = key.split('.')[-1]
+            shard_param_list = []
             # Find some plan for specific parameter, such as
             # "lm_head.weight": ColWiseParallel()
-            # Only support weight or bias.
-            if key.endswith(".weight"):
-                key = key.replace(".weight", "")
-                shard_bias = False
-            elif key.endswith(".bias"):
-                key = key.replace(".bias", "")
-                shard_weight = False
+            # "qkv_porj.lora_A" ColWiseParallel()
+            # if there is no plan for specific parameter, layer will be sharded by default: layer.weight and layer.bias
+            if key.endswith(f".{attr_name}"):
+                if hasattr(layer, attr_name) and is_tensor(
+                    getattr(layer, attr_name)
+                ):
+                    key = key.replace(f".{attr_name}", "")
+                    shard_param_list.append(attr_name)
             re_find = re.match(key, name)
             if key == name or (
                 re_find is not None
                 and int(re_find.end()) - int(re_find.start()) == len(name)
             ):
                 if isinstance(plan, PlanBase):
                     plan = [plan]
-                plans.append([plan, shard_weight, shard_bias])
+                plans.append([plan, shard_param_list])
         return plans

     def tensor_parallelizer_fn(self, model):
@@ -678,19 +695,16 @@ def tensor_parallelizer_fn(self, model):
                 continue
             share_param_list[param.name] += 1
         for name, layer in model.named_sublayers():
-            plans = self.match_layer(name)
+            plans = self.match_layer(layer, name)
             layer_param_placements[layer] = {}
             if len(plans) > 0:
                 pp_idx = getattr(layer, "pipeline_stage_index", 0)
                 for plan in plans:
-                    real_plan, shard_weight, shard_bias = plan
+                    real_plan, shard_param_list = plan
                     for p in real_plan:
                         p.share_param_list = share_param_list
                         param_placements = p.apply(
-                            layer,
-                            self.get_mesh(pp_idx),
-                            shard_weight,
-                            shard_bias,
+                            layer, self.get_mesh(pp_idx), shard_param_list
                         )
                         if param_placements is not None and param_placements:
                             layer_param_placements[layer].update(
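A side effect of the rewrite above is that ColWiseParallel now chooses the shard axis from the parameter's rank rather than from the fixed names weight and bias: 2-D tensors (weight, LoRA adapter matrices) get Shard(1) on the mp axis, 1-D tensors get Shard(0). The sketch below just reproduces that branch for a ["dp", "mp"] mesh to make the resulting placements visible; it is an assumed illustration, not an API of the intermediate module.

# Assumed illustration of the rank-based placement choice in ColWiseParallel.
import paddle.distributed as dist

mesh = dist.ProcessMesh([[0, 1], [2, 3]], dim_names=["dp", "mp"])
mp_axis = mesh.dim_names.index("mp")


def colwise_placement(param_ndim):
    placement = [dist.Replicate() for _ in range(len(mesh.shape))]
    # 2-D params (weight, lora_A, lora_B) split their output dim, 1-D params
    # (bias) split dim 0, mirroring the new branch in the diff above.
    placement[mp_axis] = dist.Shard(1) if param_ndim == 2 else dist.Shard(0)
    return placement


print(colwise_placement(2))  # Replicate on dp, Shard(1) on mp
print(colwise_placement(1))  # Replicate on dp, Shard(0) on mp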
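For readers following the match_layer change: the last dotted segment of a plan key is now treated as a candidate parameter name; when the matched layer really has a tensor attribute of that name, the suffix is stripped before the regex match and the attribute name is forwarded to apply() via shard_param_list, while an empty list falls back to the old weight/bias behaviour. The standalone sketch below re-implements only that matching rule with simplified checks (hasattr instead of is_tensor, slicing instead of str.replace); the class and layer names are made up.

# Simplified, standalone re-implementation of the key-splitting rule (assumed
# names; not the PR's code).
import re


class FakeLayer:
    """Stands in for a paddle.nn.Layer that carries LoRA adapter tensors."""

    def __init__(self, param_names):
        for p in param_names:
            setattr(self, p, object())  # placeholder "tensor" attributes


def resolve_plan_key(key, layer, layer_name):
    """Return (matches, shard_param_list) for one parallelize_plan entry."""
    attr_name = key.split(".")[-1]
    shard_param_list = []
    if key.endswith(f".{attr_name}") and hasattr(layer, attr_name):
        key = key[: -len(f".{attr_name}")]  # strip the parameter suffix
        shard_param_list.append(attr_name)  # and remember what to shard
    re_find = re.match(key, layer_name)
    matches = key == layer_name or (
        re_find is not None
        and re_find.end() - re_find.start() == len(layer_name)
    )
    return matches, shard_param_list


layer = FakeLayer(["weight", "bias", "lora_A", "lora_B"])
print(resolve_plan_key("decoder.0.qkv_proj.lora_A", layer, "decoder.0.qkv_proj"))
# (True, ['lora_A'])  only lora_A is sharded by the matched plan
print(resolve_plan_key("decoder.0.qkv_proj", layer, "decoder.0.qkv_proj"))
# (True, [])          empty list: default weight/bias sharding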
8 changes: 8 additions & 0 deletions test/auto_parallel/hybrid_strategy/CMakeLists.txt
@@ -161,3 +161,11 @@ if((WITH_GPU) AND (LINUX))
   set_tests_properties(test_to_distributed_api_for_llama
                        PROPERTIES TIMEOUT "180" LABELS "RUN_TYPE=HYBRID")
 endif()
+if((WITH_GPU) AND (LINUX))
+  py_test_modules(
+    test_parallel_api_with_llama_lora MODULES test_parallel_api_with_llama_lora
+    ENVS
+    "http_proxy=;https_proxy=;PYTHONPATH=../..:${PADDLE_BINARY_DIR}/python")
+  set_tests_properties(test_parallel_api_with_llama_lora
+                       PROPERTIES TIMEOUT "360" LABELS "RUN_TYPE=HYBRID")
+endif()