add test missing feature in rank for EC (#1879)
Summary:
Pull Request resolved: #1879

As titled

Reviewed By: IvanKobzarev

Differential Revision: D54314395

fbshipit-source-id: 20f0e0c8cb1616bb53bc69e40417bbc4ca4e4377
gnahzg authored and facebook-github-bot committed Apr 17, 2024
1 parent f2e6ef0 commit b357837
Showing 1 changed file with 109 additions and 0 deletions: torchrec/distributed/tests/test_infer_shardings.py
@@ -1091,6 +1091,115 @@ def test_mix_tw_rw_sequence(self, weight_dtype: torch.dtype) -> None:
            gm_script_output = gm_script(*inputs[0])
            assert_close(sharded_output, gm_script_output)

    @unittest.skipIf(
        torch.cuda.device_count() <= 1,
        "Not enough GPUs available",
    )
    # pyre-ignore
    @given(
        weight_dtype=st.sampled_from([torch.qint8, torch.quint4x2]),
    )
    @settings(max_examples=4, deadline=None)
    def test_mix_tw_rw_sequence_missing_feature_on_rank(
        self, weight_dtype: torch.dtype
    ) -> None:
        num_embeddings = 10
        emb_dim = 16
        world_size = 2
        local_size = 2
        batch_size = 2
        local_device = torch.device("cuda:0")

        topology: Topology = Topology(world_size=world_size, compute_device="cuda")
        mi = TestModelInfo(
            dense_device=local_device,
            sparse_device=local_device,
            num_features=2,
            num_float_features=10,
            num_weighted_features=0,
            topology=topology,
        )

        mi.planner = EmbeddingShardingPlanner(
            topology=topology,
            batch_size=batch_size,
            enumerator=EmbeddingEnumerator(
                topology=topology,
                batch_size=batch_size,
                estimator=[
                    EmbeddingPerfEstimator(topology=topology, is_inference=True),
                    EmbeddingStorageEstimator(topology=topology),
                ],
            ),
        )

        mi.tables = [
            EmbeddingConfig(
                num_embeddings=num_embeddings,
                embedding_dim=emb_dim,
                name=f"table_{i}",
                feature_names=[f"feature_{i}"],
            )
            for i in range(mi.num_features)
        ]

        mi.model = KJTInputWrapper(
            module_kjt_input=torch.nn.Sequential(
                EmbeddingCollection(
                    tables=mi.tables,
                    device=mi.sparse_device,
                )
            )
        )

        mi.model.training = False
        mi.quant_model = quantize(
            mi.model,
            inplace=False,
            quant_state_dict_split_scale_bias=True,
            weight_dtype=weight_dtype,
        )
        non_sharded_model = mi.quant_model

        sharder = QuantEmbeddingCollectionSharder()

        module_plan = construct_module_sharding_plan(
            non_sharded_model._module_kjt_input[0],
            per_param_sharding={
                # table_0 is split row-wise across both ranks, while table_1 is
                # placed only on rank 1, so rank 0 holds no shard for feature_1
                # (the "missing feature on rank" case this test covers).
                "table_0": row_wise(),
                "table_1": table_wise(rank=1),
            },
            # pyre-ignore
            sharder=sharder,
            local_size=local_size,
            world_size=world_size,
        )

        plan = ShardingPlan(plan={"_module_kjt_input.0": module_plan})

        sharded_model = shard_qec(
            mi,
            sharding_type=ShardingType.ROW_WISE,
            device=local_device,
            plan=plan,
            expected_shards=None,
        )

        inputs = [
            model_input_to_forward_args_kjt(inp.to(local_device))
            for inp in prep_inputs(mi, world_size, batch_size, long_indices=False)
        ]

        # The sharded model should match the unsharded reference output.
        sharded_model.load_state_dict(non_sharded_model.state_dict())
        sharded_output = sharded_model(*inputs[0])
        non_sharded_output = non_sharded_model(*inputs[0])
        assert_close(non_sharded_output, sharded_output)

        # The sharded model should also trace and script cleanly, with the
        # scripted graph reproducing the eager sharded output.
        gm: torch.fx.GraphModule = symbolic_trace(sharded_model)
        gm_script = torch.jit.script(gm)
        gm_script_output = gm_script(*inputs[0])
        assert_close(sharded_output, gm_script_output)

    @unittest.skipIf(
        torch.cuda.device_count() <= 2,
        "Not enough GPUs available",
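
For context on what the new test exercises: the per-parameter plan shards table_0 row-wise across both ranks but places table_1 table-wise on rank 1 only, so rank 0 serves lookups for feature_1 without holding any of its rows. Below is a minimal sketch of just that placement, reusing the same helpers the test calls; the import paths and the quant_ec placeholder are assumptions for illustration and are not part of this commit.

from torchrec.distributed.quant_embedding import QuantEmbeddingCollectionSharder
from torchrec.distributed.sharding_plan import (
    construct_module_sharding_plan,
    row_wise,
    table_wise,
)
from torchrec.distributed.types import ShardingPlan

# quant_ec stands in for the quantized EmbeddingCollection built in the test,
# holding tables "table_0" and "table_1".
module_plan = construct_module_sharding_plan(
    quant_ec,
    per_param_sharding={
        "table_0": row_wise(),          # rows split across rank 0 and rank 1
        "table_1": table_wise(rank=1),  # whole table on rank 1; rank 0 has no shard
    },
    sharder=QuantEmbeddingCollectionSharder(),
    local_size=2,
    world_size=2,
)
plan = ShardingPlan(plan={"_module_kjt_input.0": module_plan})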
