Modernize DDP tests #2658

Closed · wants to merge 1 commit
torchrec/distributed/composable/tests/test_ddp.py: 303 changes (129 additions, 174 deletions)
@@ -9,13 +9,10 @@

#!/usr/bin/env python3

import os
import tempfile
import unittest
import uuid

import torch
from torch import distributed as dist
from torch.distributed._composable import replicate
from torch.distributed._shard.api import ShardedTensor
from torch.distributed.checkpoint import (
@@ -24,167 +21,142 @@
load_state_dict,
save_state_dict,
)
from torch.distributed.launcher.api import elastic_launch, LaunchConfig
from torchrec.distributed.shard import shard as trec_shard, shard_modules
from torchrec.distributed.sharding_plan import column_wise
from torchrec.distributed.test_utils.multi_process import (
MultiProcessContext,
MultiProcessTestBase,
)
from torchrec.distributed.test_utils.test_model import ModelInput, TestSparseNN
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.test_utils import skip_if_asan


class DDPTest(unittest.TestCase):
class DDPTest(MultiProcessTestBase):
@classmethod
def _run_init_parameters(cls, path: str) -> None:
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.is_available():
device: torch.device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
else:
device: torch.device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
init_method=f"file://{os.path.join(path, 'dist_rdvz')}",
)
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=device,
)
# Put all tensors on meta device, then init_params should
# materialize them.
for name, param in m._parameters.items():
if isinstance(param, torch.Tensor):
m._parameters[name] = torch.nn.Parameter(
torch.empty_like(param, device="meta"),
requires_grad=param.requires_grad,
def _run_init(cls, rank: int, world_size: int) -> None:
with MultiProcessContext(rank, world_size, "nccl") as ctx:
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)

shard_modules(m, device=device, init_params=True)
# init_params should move m to `device`
for p in m.parameters():
assert p.device == device
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=ctx.device,
)
# Put all tensors on meta device, then init_params should
# materialize them.
for name, param in m._parameters.items():
if isinstance(param, torch.Tensor):
m._parameters[name] = torch.nn.Parameter(
torch.empty_like(param, device="meta"),
requires_grad=param.requires_grad,
)

shard_modules(m, device=ctx.device, init_params=True)
# init_params should move m to `device`
for p in m.parameters():
assert p.device == ctx.device

@classmethod
def _run(cls, path: str) -> None:
rank = int(os.environ["LOCAL_RANK"])
world_size = int(os.environ["WORLD_SIZE"])
if torch.cuda.is_available():
device: torch.device = torch.device(f"cuda:{rank}")
backend = "nccl"
torch.cuda.set_device(device)
else:
device: torch.device = torch.device("cpu")
backend = "gloo"
dist.init_process_group(
backend=backend,
rank=rank,
world_size=world_size,
init_method=f"file://{os.path.join(path, 'dist_rdvz')}",
)
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
def _run(cls, rank: int, world_size: int, path: str) -> None:
with MultiProcessContext(rank, world_size, "nccl") as ctx:
num_float_features = 32

tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="table_" + str(i),
feature_names=["feature_" + str(i)],
)
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=ctx.device,
)
for i in range(3)
]
weighted_tables = [
EmbeddingBagConfig(
num_embeddings=(i + 1) * 10,
embedding_dim=(i + 1) * 4 * world_size,
name="weighted_table_" + str(i),
feature_names=["weighted_feature_" + str(i)],
m.sparse.ebc = trec_shard(
module=m.sparse.ebc,
device=ctx.device,
plan=column_wise(ranks=list(range(world_size))),
)
for i in range(2)
]
m = TestSparseNN(
tables=tables,
num_float_features=num_float_features,
weighted_tables=weighted_tables,
dense_device=device,
)
m.sparse.ebc = trec_shard(
module=m.sparse.ebc,
device=device,
plan=column_wise(ranks=list(range(world_size))),
)
m.sparse.weighted_ebc = trec_shard(
module=m.sparse.weighted_ebc,
device=device,
plan=column_wise(ranks=list(range(world_size))),
)
m.over = replicate(m.over)
m.dense = replicate(m.dense)

######## run one iteration ########
_, local_batch = ModelInput.generate(
batch_size=8,
world_size=world_size,
num_float_features=num_float_features,
tables=tables,
weighted_tables=weighted_tables,
)
batch = local_batch[0].to(device)
m(batch)[1].sum().backward()

state_dict = m.state_dict()
writer = FileSystemWriter(path=path)
reader = FileSystemReader(path=path)
save_state_dict(state_dict, writer)

p_sum = torch.zeros(1, device=device)
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
continue
p = p.local_tensor()
p_sum += p.sum()
p.zero_()
assert p.sum() == 0
load_state_dict(state_dict, reader)
m.load_state_dict(state_dict)

p_sum_loaded = torch.zeros(1, device=device)
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
continue
p = p.local_tensor()
p_sum_loaded += p.sum()
# TODO: debug why failing on OSS
# assert p_sum.allclose(p_sum_loaded)
m.sparse.weighted_ebc = trec_shard(
module=m.sparse.weighted_ebc,
device=ctx.device,
plan=column_wise(ranks=list(range(world_size))),
)
m.over = replicate(m.over)
m.dense = replicate(m.dense)

######## run one iteration ########
_, local_batch = ModelInput.generate(
batch_size=8,
world_size=world_size,
num_float_features=num_float_features,
tables=tables,
weighted_tables=weighted_tables,
)
batch = local_batch[0].to(ctx.device)
m(batch)[1].sum().backward()

state_dict = m.state_dict()
writer = FileSystemWriter(path=path)
reader = FileSystemReader(path=path)
save_state_dict(state_dict, writer)

p_sum = torch.zeros(1, device=ctx.device)
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
continue
p = p.local_tensor()
p_sum += p.sum()
p.zero_()
assert p.sum() == 0
load_state_dict(state_dict, reader)
m.load_state_dict(state_dict)

p_sum_loaded = torch.zeros(1, device=ctx.device)
for p in m.parameters():
with torch.no_grad():
if isinstance(p, ShardedTensor):
if not p.local_shards():
continue
p = p.local_tensor()
p_sum_loaded += p.sum()
# TODO: debug why failing on OSS
# assert p_sum.allclose(p_sum_loaded)

@skip_if_asan
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@@ -195,18 +167,10 @@ def _run(cls, path: str) -> None:
)
def test_checkpoint(self) -> None:
with tempfile.TemporaryDirectory() as path:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
rdzv_endpoint="localhost:0",
start_method="spawn",
monitor_interval=1,
max_restarts=0,
self._run_multi_process_test(
callable=self._run,
path=path,
)
elastic_launch(config=lc, entrypoint=self._run)(path)

@skip_if_asan
# pyre-fixme[56]: Pyre was not able to infer the type of argument
@@ -216,15 +180,6 @@ def test_checkpoint(self) -> None:
"Not enough GPUs, this test requires at least two GPUs",
)
def test_init_params(self) -> None:
with tempfile.TemporaryDirectory() as path:
lc = LaunchConfig(
min_nodes=1,
max_nodes=1,
nproc_per_node=2,
run_id=str(uuid.uuid4()),
rdzv_backend="c10d",
start_method="spawn",
monitor_interval=1,
max_restarts=0,
)
elastic_launch(config=lc, entrypoint=self._run_init_parameters)(path)
self._run_multi_process_test(
callable=self._run_init,
)
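
For context on the pattern this diff adopts: the removed tests launched worker processes with elastic_launch/LaunchConfig and did their own dist.init_process_group from LOCAL_RANK/WORLD_SIZE, while the new tests delegate all of that to torchrec's MultiProcessTestBase and MultiProcessContext. The sketch below shows the pattern in isolation. It is an illustration only: the helper behavior is inferred from how it is used in this diff (a _run_multi_process_test that spawns one process per rank and invokes the callable with rank, world_size, and any extra keyword arguments, and a context manager that sets up and tears down the process group and exposes ctx.device); exact signatures may differ, and MyDistributedTest / test_all_reduce are made-up names.

import unittest

import torch
import torch.distributed as dist
from torchrec.distributed.test_utils.multi_process import (
    MultiProcessContext,
    MultiProcessTestBase,
)


class MyDistributedTest(MultiProcessTestBase):
    @classmethod
    def _run(cls, rank: int, world_size: int) -> None:
        # MultiProcessContext is assumed (from its use in the diff above) to set
        # up the process group for `rank`/`world_size`, pick the per-rank device,
        # and tear everything down on exit.
        with MultiProcessContext(rank, world_size, "nccl") as ctx:
            # Each rank contributes its rank id; after all_reduce every rank
            # should hold 0 + 1 + ... + (world_size - 1).
            t = torch.tensor([float(rank)], device=ctx.device)
            dist.all_reduce(t)
            assert t.item() == sum(range(world_size))

    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs, this test requires at least two GPUs",
    )
    def test_all_reduce(self) -> None:
        # _run_multi_process_test is assumed to spawn one process per rank and
        # call `callable(rank, world_size, **kwargs)`; extra keyword arguments
        # (e.g. `path=path` in test_checkpoint above) are forwarded as-is.
        self._run_multi_process_test(callable=self._run)

Compared with the elastic_launch version being removed, rendezvous setup, backend selection, and device placement no longer live in each test body, which is what lets _run_init and _run above shrink to just the model logic.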
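
The core of test_checkpoint is a save/zero/load round trip over a sharded module. Factored into a standalone helper for readability, it looks like the sketch below. This is only a restatement of the logic in the diff: checkpoint_round_trip is an illustrative name, the function must be called from inside an initialized process group (for example within MultiProcessContext) on a module that has already been sharded, and the final allclose check is the one the test itself still leaves commented out with "# TODO: debug why failing on OSS".

import torch
from torch.distributed._shard.api import ShardedTensor
from torch.distributed.checkpoint import (
    FileSystemReader,
    FileSystemWriter,
    load_state_dict,
    save_state_dict,
)


def checkpoint_round_trip(m: torch.nn.Module, path: str, device: torch.device) -> None:
    # 1. Persist the (possibly sharded) state dict to `path`.
    state_dict = m.state_dict()
    save_state_dict(state_dict, FileSystemWriter(path=path))

    # 2. Record a parameter checksum, then zero every locally held value so a
    #    successful restore is observable.
    p_sum = torch.zeros(1, device=device)
    with torch.no_grad():
        for p in m.parameters():
            if isinstance(p, ShardedTensor):
                if not p.local_shards():
                    continue  # this rank holds no shard of the tensor
                p = p.local_tensor()
            p_sum += p.sum()
            p.zero_()

    # 3. Read the checkpoint back into the state dict, then into the module.
    load_state_dict(state_dict, FileSystemReader(path=path))
    m.load_state_dict(state_dict)

    # 4. The restored checksum should match the one taken before zeroing.
    #    (The test above keeps this assertion commented out for OSS.)
    p_sum_loaded = torch.zeros(1, device=device)
    with torch.no_grad():
        for p in m.parameters():
            if isinstance(p, ShardedTensor):
                if not p.local_shards():
                    continue
                p = p.local_tensor()
            p_sum_loaded += p.sum()
    assert torch.allclose(p_sum, p_sum_loaded)

Zeroing the local shards before loading is what makes a successful restore observable: if load_state_dict restored nothing, the post-load checksum would be zero rather than equal to the pre-save checksum.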