Skip to content

Commit

Permalink
Modernize DDP tests
Browse files Browse the repository at this point in the history
Summary: These tests are timing out on internal testing infra.  modernizing to resolve.

Differential Revision: D67720829
  • Loading branch information
dstaay-fb authored and facebook-github-bot committed Dec 30, 2024
1 parent f059a49 commit a091957
Showing 1 changed file with 129 additions and 174 deletions.
303 changes: 129 additions & 174 deletions torchrec/distributed/composable/tests/test_ddp.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,10 @@

#!/usr/bin/env python3

import os
import tempfile
import unittest
import uuid

import torch
from torch import distributed as dist
from torch.distributed._composable import replicate
from torch.distributed._shard.api import ShardedTensor
from torch.distributed.checkpoint import (
Expand All @@ -24,167 +21,142 @@
load_state_dict,
save_state_dict,
)
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torchrec.distributed.shard import shard as trec_shard, shard_modules
from torchrec.distributed.sharding_plan import column_wise
from torchrec.distributed.test_utils.multi_process import (
MultiProcessContext,
MultiProcessTestBase,
)
from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.test_utils import skip_if_asan


class DDPTest(unittest.TestCase):
class DDPTest(MultiProcessTestBase):
@classmethod
def _run_init_parameters(cls, path: str) -> None:
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.is_available():
device: torch.device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
else:
device: torch.device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
init_method=f"file://{os.path.join(path, 'dist_rdvz')}",
)
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=device,
)
# Put all tensors on meta device, then init_params should
# materialize them.
for name, param in m._parameters.items():
if isinstance(param, torch.Tensor):
m._parameters[name] = torch.nn.Parameter(
torch.empty_like(param, device="meta"),
requires_grad=param.requires_grad,
def _run_init(cls, rank: int, world_size: int) -> None:
with MultiProcessContext(rank, world_size, "nccl") as ctx:
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)

shard_modules(m, device=device, init_params=True)
# init_params should move m to `device`
for p in m.parameters():
assert p.device == device
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=ctx.device,
)
# Put all tensors on meta device, then init_params should
# materialize them.
for name, param in m._parameters.items():
if isinstance(param, torch.Tensor):
m._parameters[name] = torch.nn.Parameter(
torch.empty_like(param, device="meta"),
requires_grad=param.requires_grad,
)

shard_modules(m, device=ctx.device, init_params=True)
# init_params should move m to `device`
for p in m.parameters():
assert p.device == ctx.device

@classmethod
def _run(cls, path: str) -> None:
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.is_available():
device: torch.device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
else:
device: torch.device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
init_method=f"file://{os.path.join(path, 'dist_rdvz')}",
)
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
def _run(cls, rank: int, world_size: int, path: str) -> None:
with MultiProcessContext(rank, world_size, "nccl") as ctx:
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=ctx.device,
)
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
m.sparse.ebc = trec_shard(
module=m.sparse.ebc,
device=ctx.device,
plan=column_wise(ranks=list(range(world_size))),
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=device,
)
m.sparse.ebc = trec_shard(
module=m.sparse.ebc,
device=device,
plan=column_wise(ranks=list(range(world_size))),
)
m.sparse.weighted_ebc = trec_shard(
module=m.sparse.weighted_ebc,
device=device,
plan=column_wise(ranks=list(range(world_size))),
)
m.over = replicate(m.over)
m.dense = replicate(m.dense)

######## run one iteration ########
_, local_batch = ModelInput.generate(
batch_size=8,
world_size=world_size,
num_float_features=num_float_features,
tables=tables,
weighted_tables=weighted_tables,
)
batch = local_batch[0].to(device)
m(batch)[1].sum().backward()

state_dict = m.state_dict()
writer = FileSystemWriter(path=path)
reader = FileSystemReader(path=path)
save_state_dict(state_dict, writer)

p_sum = torch.zeros(1, device=device)
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
continue
p = p.local_tensor()
p_sum += p.sum()
p.zero_()
assert p.sum() == 0
load_state_dict(state_dict, reader)
m.load_state_dict(state_dict)

p_sum_loaded = torch.zeros(1, device=device)
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
continue
p = p.local_tensor()
p_sum_loaded += p.sum()
# TODO: debug why failing on OSS
# assert p_sum.allclose(p_sum_loaded)
m.sparse.weighted_ebc = trec_shard(
module=m.sparse.weighted_ebc,
device=ctx.device,
plan=column_wise(ranks=list(range(world_size))),
)
m.over = replicate(m.over)
m.dense = replicate(m.dense)

######## run one iteration ########
_, local_batch = ModelInput.generate(
batch_size=8,
world_size=world_size,
num_float_features=num_float_features,
tables=tables,
weighted_tables=weighted_tables,
)
batch = local_batch[0].to(ctx.device)
m(batch)[1].sum().backward()

state_dict = m.state_dict()
writer = FileSystemWriter(path=path)
reader = FileSystemReader(path=path)
save_state_dict(state_dict, writer)

p_sum = torch.zeros(1, device=ctx.device)
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
continue
p = p.local_tensor()
p_sum += p.sum()
p.zero_()
assert p.sum() == 0
load_state_dict(state_dict, reader)
m.load_state_dict(state_dict)

p_sum_loaded = torch.zeros(1, device=ctx.device)
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
continue
p = p.local_tensor()
p_sum_loaded += p.sum()
# TODO: debug why failing on OSS
# assert p_sum.allclose(p_sum_loaded)

@skip_if_asan
# pyre-fixme[56]: Pyre was not able to infer the type of argument
Expand All @@ -195,18 +167,10 @@ def _run(cls, path: str) -> None:
)
def test_checkpoint(self) -> None:
with tempfile.TemporaryDirectory() as path:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
rdzv_endpoint="localhost:0",
start_method="spawn",
monitor_interval=1,
max_restarts=0,
self._run_multi_process_test(
callable=self._run,
path=path,
)
elastic_launch(config=lc, entrypoint=self._run)(path)

@skip_if_asan
# pyre-fixme[56]: Pyre was not able to infer the type of argument
Expand All @@ -216,15 +180,6 @@ def test_checkpoint(self) -> None:
"Not enough GPUs, this test requires at least two GPUs",
)
def test_init_params(self) -> None:
with tempfile.TemporaryDirectory() as path:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
start_method="spawn",
monitor_interval=1,
max_restarts=0,
)
elastic_launch(config=lc, entrypoint=self._run_init_parameters)(path)
self._run_multi_process_test(
callable=self._run_init,
)

0 comments on commit a091957

Please sign in to comment.