Skip to content

Commit

Permalink
refactor: remode ddp func
Browse files Browse the repository at this point in the history
  • Loading branch information
samsja committed Sep 22, 2024
1 parent 3b1f740 commit ee43118
Showing 1 changed file with 2 additions and 9 deletions.
11 changes: 2 additions & 9 deletions src/zeroband/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@
from zeroband.utils.logging import get_logger


def ddp_setup():
"""
Initialize the distributed process group.
"""
init_process_group()
torch.cuda.set_device(world_info.local_rank)


class DataConfig(BaseConfig):
seq_length: int = 1024
fake: bool = False
Expand Down Expand Up @@ -255,7 +247,8 @@ def train(config: Config):
world_info = get_world_info()
logger = get_logger()

ddp_setup()
init_process_group()
torch.cuda.set_device(world_info.local_rank)

config = Config(**parse_argv())
logger.debug(f"config: {config.model_dump()}")
Expand Down

0 comments on commit ee43118

Please sign in to comment.