Pipeline parallel (#40)
Signed-off-by: Mayank Mishra <[email protected]>
mayank31398 authored Oct 29, 2024
1 parent b7949d2 commit ee7ed18
Showing 56 changed files with 1,495 additions and 615 deletions.
2 changes: 1 addition & 1 deletion configs/ultra-long-context-length/3b-131k-training.yml
@@ -47,7 +47,7 @@ mixed_precision_args:
   dtype: bf16

 distributed_args:
-  tensor_parallel_size: 8
+  tensor_parallel_world_size: 8
   fsdp_algorithm: 2
   sequence_parallel: true
   tensor_parallel_word_embeddings: true
2 changes: 1 addition & 1 deletion configs/ultra-long-context-length/3b-65k-training.yml
@@ -47,7 +47,7 @@ mixed_precision_args:
   dtype: bf16

 distributed_args:
-  tensor_parallel_size: 8
+  tensor_parallel_world_size: 8
  fsdp_algorithm: 2
   sequence_parallel: true
   tensor_parallel_word_embeddings: true
2 changes: 1 addition & 1 deletion configs/ultra-long-context-length/8b-65k-training.yml
@@ -47,7 +47,7 @@ mixed_precision_args:
   dtype: bf16

 distributed_args:
-  tensor_parallel_size: 8
+  tensor_parallel_world_size: 8
   fsdp_algorithm: 2
   sequence_parallel: true
   tensor_parallel_word_embeddings: true
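All three ultra-long-context configs receive the same rename. For orientation, a hypothetical distributed_args block that also opts into the new pipeline parallel fields introduced in dolomite_engine/arguments.py below might look like the following; the values, and in particular the schedule string, are illustrative guesses rather than settings taken from this commit.

distributed_args:
  # renamed in this commit from tensor_parallel_size
  tensor_parallel_world_size: 8
  fsdp_algorithm: 2
  sequence_parallel: true
  tensor_parallel_word_embeddings: true
  # new pipeline parallel options (illustrative values; valid schedule names are not shown in this diff)
  pipeline_parallel_world_size: 2
  num_pipeline_stages: 2
  pipeline_parallel_schedule: 1F1B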
24 changes: 20 additions & 4 deletions dolomite_engine/arguments.py
@@ -324,11 +324,13 @@ class DistributedArgs(BaseArgs):
     # whether to use a dispatching dataloader
     dispatching_dataloader: bool = False
     # tensor parallel world size
-    tensor_parallel_size: int = 1
+    tensor_parallel_world_size: int = 1
     # tensor parallel embeddings
     tensor_parallel_word_embeddings: bool = False
     # whether to use sequence parallel
     sequence_parallel: bool = False
+    # pipeline parallel world size
+    pipeline_parallel_world_size: int = 1
     # data parallel world size
     data_parallel_size: int | None = None
     # distributed timeout for NCCL in minutes
@@ -337,6 +339,10 @@ class DistributedArgs(BaseArgs):
     fsdp_algorithm: int = 1
     # whether to sync every gradient accumulation step
     sync_every_gradient_accumulation_step: bool = False
+    # total number of pipeline stages
+    num_pipeline_stages: int = 1
+    # pipeline parallel schedule to use
+    pipeline_parallel_schedule: str | None = None
     # whether to use async-TP
     use_async_tensor_parallel: bool = False

@@ -346,14 +352,14 @@ def model_post_init(self, __context: Any) -> None:
             self.communication_dtype = normalize_dtype_string(self.communication_dtype)

         if self.sequence_parallel:
-            assert self.tensor_parallel_size > 1, "tensor parallel needs to be enabled for sequence parallel"
+            assert self.tensor_parallel_world_size > 1, "tensor parallel needs to be enabled for sequence parallel"

         if self.tensor_parallel_word_embeddings:
             assert (
-                self.tensor_parallel_size > 1
+                self.tensor_parallel_world_size > 1
             ), "tensor parallel needs to be enabled when using tensor parallel word embeddings"

-        if self.tensor_parallel_size > 1:
+        if self.tensor_parallel_world_size > 1:
             version = Version(torch.__version__).release
             version = [str(i) for i in version]
             version = ".".join(version)
@@ -369,6 +375,13 @@ def model_post_init(self, __context: Any) -> None:
         if self.use_async_tensor_parallel:
             assert self.sequence_parallel, "sequence parallel should be enabled for using async-TP"

+        assert (
+            self.num_pipeline_stages % self.pipeline_parallel_world_size == 0
+        ), "num_pipeline_stages should be a multiple of pipeline_parallel_world_size"
+
+        if self.num_pipeline_stages > 1:
+            _check_not_None([(self.pipeline_parallel_schedule, "pipeline_parallel_schedule")])
+

 class AimArgs(BaseArgs):
     # aim repo, experiment logs are saved here
@@ -491,6 +504,9 @@ def model_post_init(self, __context: Any) -> None:
         # datasets
         _check_datasets(self.datasets)

+        if self.distributed_args.num_pipeline_stages > 1 and self.training_parameters.eval_during_training:
+            raise NotImplementedError("evaluation is not supported with pipeline parallel")
+

 class GenerationParameters(BaseArgs):
     # batch size
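The new pipeline parallel checks in model_post_init reduce to two constraints: the number of stages must divide evenly across the pipeline ranks, and a schedule must be named once more than one stage exists. A minimal standalone sketch of that logic follows; it is not the actual DistributedArgs class, the numbers are illustrative, and "1F1B" is only a placeholder for whatever schedule strings the engine accepts.

# Standalone sketch of the validation added in this commit (not the real
# DistributedArgs class; names mirror the new fields, values are examples).
def validate_pipeline_args(
    num_pipeline_stages: int,
    pipeline_parallel_world_size: int,
    pipeline_parallel_schedule: str | None,
) -> None:
    # every pipeline rank must host the same number of stages
    assert (
        num_pipeline_stages % pipeline_parallel_world_size == 0
    ), "num_pipeline_stages should be a multiple of pipeline_parallel_world_size"

    # a schedule is required as soon as pipeline parallelism is actually used
    if num_pipeline_stages > 1:
        assert pipeline_parallel_schedule is not None, "pipeline_parallel_schedule must be set"


# example: 8 stages spread over 4 pipeline ranks -> 2 stages per rank,
# the kind of layout interleaved schedules rely on
validate_pipeline_args(8, 4, "1F1B")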