From ee43118e23d797c026ea49c78f801103e40d711f Mon Sep 17 00:00:00 2001 From: Sami Jaghouar Date: Sun, 22 Sep 2024 23:42:44 +0000 Subject: [PATCH] refactor: remove ddp func --- src/zeroband/train.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/src/zeroband/train.py b/src/zeroband/train.py index 5bb4243d..39cae775 100644 --- a/src/zeroband/train.py +++ b/src/zeroband/train.py @@ -29,14 +29,6 @@ from zeroband.utils.logging import get_logger -def ddp_setup(): - """ - Initialize the distributed process group. - """ - init_process_group() - torch.cuda.set_device(world_info.local_rank) - - class DataConfig(BaseConfig): seq_length: int = 1024 fake: bool = False @@ -255,7 +247,8 @@ def train(config: Config): world_info = get_world_info() logger = get_logger() - ddp_setup() + init_process_group() + torch.cuda.set_device(world_info.local_rank) config = Config(**parse_argv()) logger.debug(f"config: {config.model_dump()}")