KJT tensor compute path for List[int] *per_key (#1867)
Summary: Pull Request resolved: #1867

Differential Revision: D55932356
Ivan Kobzarev authored and facebook-github-bot committed Apr 10, 2024
1 parent 579fe9f commit 32991fd
Showing 3 changed files with 824 additions and 35 deletions.
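The diffs below thread a new Optional[torch.Tensor] leaf, stride_per_key_per_rank_tensor, through KJTInputExportWrapper and the PT2 tests so that per-key stride information travels as a tensor rather than a List[int]. A minimal sketch (not part of the commit) of the intended usage, relying on the constructor keyword that the diff below exercises; the keys, values, and lengths here are made up:

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

keys = ["f1", "f2"]
values = torch.tensor([2, 3, 4, 5, 6])
lengths = torch.tensor([1, 2, 1, 1])  # per (key, batch-entry) lengths: f1 -> [1, 2], f2 -> [1, 1]
# One row per key, one column per rank; a single rank with batch stride 2 for both keys.
stride_tensor = torch.tensor([[2], [2]], dtype=torch.int32)

kjt = KeyedJaggedTensor(
    keys=keys,
    values=values,
    lengths=lengths,
    stride_per_key_per_rank_tensor=stride_tensor,
)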
2 changes: 2 additions & 0 deletions torchrec/distributed/test_utils/infer_utils.py
@@ -122,6 +122,7 @@ def forward(
self,
values: torch.Tensor,
lengths: torch.Tensor,
stride_per_key_per_rank_tensor: Optional[torch.Tensor],
# pyre-ignore
*args,
# pyre-ignore
@@ -131,6 +132,7 @@ def forward(
keys=self._kjt_keys,
values=values,
lengths=lengths,
stride_per_key_per_rank_tensor=stride_per_key_per_rank_tensor,
)
output = self._module_kjt_input(kjt, *args, **kwargs)
# TODO(ivankobzarev): Support None leaves in dynamo/export (e.g., KJT offsets)
143 changes: 120 additions & 23 deletions torchrec/distributed/tests/test_pt2.py
@@ -9,9 +9,10 @@

import sys
import unittest
from typing import List, Tuple
from typing import Dict, List, Tuple

import torch
from torchrec.sparse.jagged_tensor import JaggedTensor

try:
# pyre-ignore
@@ -50,7 +51,17 @@ def make_kjt(values: List[int], lengths: List[int]) -> KeyedJaggedTensor:
return kjt


def _sharded_quant_ebc_model() -> Tuple[torch.nn.Module, List[KeyedJaggedTensor]]:
def kjt_module_kjt_inputs(kjt: KeyedJaggedTensor) -> Tuple:
return (
kjt._values,
kjt._lengths,
kjt._stride_per_key_per_rank_tensor,
)


def _sharded_quant_ebc_model(
sharding_type: ShardingType,
) -> Tuple[torch.nn.Module, List[KeyedJaggedTensor]]:
num_embeddings = 256
emb_dim = 12
world_size = 2
@@ -73,7 +84,6 @@ def _sharded_quant_ebc_model() -> Tuple[torch.nn.Module, List[KeyedJaggedTensor]
]
model: torch.nn.Module = KJTInputExportWrapper(mi.quant_model, input_kjts[0].keys())

sharding_type: ShardingType = ShardingType.TABLE_WISE
sharder = TestQuantEBCSharder(
sharding_type=sharding_type.value,
kernel_type=EmbeddingComputeKernel.QUANT.value,
@@ -96,7 +106,44 @@ def _sharded_quant_ebc_model() -> Tuple[torch.nn.Module, List[KeyedJaggedTensor]
return sharded_model, input_kjts


def kjt_for_tracing(
kjt: KeyedJaggedTensor, always_to_vb: bool = False
) -> KeyedJaggedTensor:
is_vb = kjt.variable_stride_per_key()
if always_to_vb and not is_vb:
stride: int = kjt.stride()
n = len(kjt.keys())
return KeyedJaggedTensor(
keys=kjt.keys(),
values=kjt.values(),
lengths=kjt.lengths(),
stride_per_key_per_rank=[[stride]] * n,
inverse_indices=(
kjt.keys(),
torch.arange(stride)
.expand(n, stride)
.contiguous()
.to(device=kjt.device()),
),
stride_per_key_per_rank_tensor=torch.full([n], fill_value=stride).view(
n, 1
),
)

return KeyedJaggedTensor(
keys=kjt.keys(),
values=kjt.values(),
lengths=kjt.lengths(),
stride_per_key_per_rank=(kjt.stride_per_key_per_rank() if is_vb else None),
inverse_indices=kjt.inverse_indices_or_none() if is_vb else None,
stride_per_key_per_rank_tensor=torch.tensor(
kjt.stride_per_key_per_rank(), dtype=torch.int32
),
)


class TestPt2(unittest.TestCase):

def _test_kjt_input_module(
self,
kjt_input_module: torch.nn.Module,
@@ -135,21 +182,46 @@ def _test_kjt_input_module(
pt2_ir_output = pt2_ir.module()(*inputs)
assert_close(eager_output, pt2_ir_output)

# Separate test for Dynamo, as it falls back to the VB (variable batch) path.
# Torchrec has lazily initialized modules that depend on the first input, so eager must be run with the tracing inputs.
# Other test cases do not need to go through the VB path.
def _test_kjt_input_module_dynamo_compile(
self,
kjt_input_module: torch.nn.Module,
kjt_keys: List[str],
# pyre-ignore
inputs,
) -> None:
with dynamo_skipfiles_allow("torchrec"):
EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt_keys)
eager_output = EM(*inputs)
torch._dynamo.config.capture_scalar_outputs = True
torch._dynamo.config.capture_dynamic_output_shape_ops = True

dynamo_eager_out = torch.compile(EM, backend="eager", fullgraph=True)(
*inputs
)
assert_close(eager_output, dynamo_eager_out)

def test_kjt_split(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor):
return kjt.split([1, 2, 1])

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
segments: List[int] = [1, 2, 1]
self._test_kjt_input_module(
M(),
kjt.keys(),
(kjt._values, kjt._lengths),
kjt_module_kjt_inputs(kjt),
test_aot_inductor=False,
test_dynamo=False,
test_pt2_ir_export=True,
)
self._test_kjt_input_module_dynamo_compile(
M(),
kjt.keys(),
kjt_module_kjt_inputs(kjt_for_tracing(kjt)),
)

def test_kjt_permute(self) -> None:
class M(torch.nn.Module):
@@ -158,13 +230,23 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
indices: List[int] = [1, 0, 3, 2]
# pyre-ignore
inputs_fn = lambda kjt: (
*kjt_module_kjt_inputs(kjt),
indices,
)
self._test_kjt_input_module(
M(),
kjt.keys(),
(kjt._values, kjt._lengths, indices),
inputs_fn(kjt),
test_aot_inductor=False,
test_pt2_ir_export=True,
)
self._test_kjt_input_module_dynamo_compile(
M(),
kjt.keys(),
inputs_fn(kjt_for_tracing(kjt)),
)

def test_kjt_length_per_key(self) -> None:
class M(torch.nn.Module):
@@ -176,7 +258,7 @@ def forward(self, kjt: KeyedJaggedTensor):
self._test_kjt_input_module(
M(),
kjt.keys(),
(kjt._values, kjt._lengths),
kjt_module_kjt_inputs(kjt),
test_aot_inductor=False,
test_pt2_ir_export=True,
)
@@ -191,7 +273,7 @@ def forward(self, kjt: KeyedJaggedTensor):
self._test_kjt_input_module(
M(),
kjt.keys(),
(kjt._values, kjt._lengths),
kjt_module_kjt_inputs(kjt),
test_aot_inductor=False,
test_pt2_ir_export=True,
)
@@ -206,51 +288,65 @@ def forward(self, kjt: KeyedJaggedTensor):
return out0, out1

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])

self._test_kjt_input_module(
M(),
kjt.keys(),
(kjt._values, kjt._lengths),
kjt_module_kjt_inputs(kjt),
test_dynamo=False,
test_aot_inductor=False,
test_pt2_ir_export=True,
)
self._test_kjt_input_module_dynamo_compile(
M(),
kjt.keys(),
kjt_module_kjt_inputs(kjt_for_tracing(kjt)),
)

def test_kjt_to_dict(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor) -> Dict[str, JaggedTensor]:
return kjt.to_dict()

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])

self._test_kjt_input_module_dynamo_compile(
M(),
kjt.keys(),
kjt_module_kjt_inputs(kjt_for_tracing(kjt)),
)

# pyre-ignores
@unittest.skipIf(
torch.cuda.device_count() <= 1,
"Not enough GPUs available",
)
def test_sharded_quant_ebc_dynamo_export_aot_inductor(self) -> None:
sharded_model, input_kjts = _sharded_quant_ebc_model()
sharded_model, input_kjts = _sharded_quant_ebc_model(ShardingType.TABLE_WISE)
kjt = input_kjts[0]
sharded_model(kjt.values(), kjt.lengths())
sharded_model(*kjt_module_kjt_inputs(kjt))

model: torch.nn.Module = sharded_model
model.training = False
replace_registered_tbes_with_mock_tbes(model)
replace_sharded_quant_modules_tbes_with_mock_tbes(model)

example_inputs = (kjt.values(), kjt.lengths())

example_inputs = kjt_module_kjt_inputs(kjt)
# pyre-ignore
def kjt_to_inputs(kjt):
return (kjt.values(), kjt.lengths())

expected_outputs = [model(*kjt_to_inputs(kjt)) for kjt in input_kjts[1:]]

expected_outputs = [
model(*kjt_module_kjt_inputs(kjt)) for kjt in input_kjts[1:]
]
device: str = "cuda"

with dynamo_skipfiles_allow("torchrec"):
tracing_values = kjt.values()
tracing_lengths = kjt.lengths()
torch._dynamo.mark_dynamic(tracing_values, 0)
dynamo_gm, guard = torch._dynamo.export(model, same_signature=False)(
tracing_values, tracing_lengths
tracing_values, tracing_lengths, kjt._stride_per_key_per_rank_tensor
)
dynamo_gm.print_readable()
dynamo_actual_outputs = [ # noqa
dynamo_gm(*kjt_to_inputs(kjt)) for kjt in input_kjts[1:]
dynamo_gm(*kjt_module_kjt_inputs(kjt)) for kjt in input_kjts[1:]
]
# TODO(ivankobzarev): Investigate why dynamo outputs differ from the expected outputs while AOT outputs are correct.
# assert_close(expected_outputs, dynamo_actual_outputs)
@@ -265,7 +361,8 @@ def kjt_to_inputs(kjt):
aot_inductor_module(*example_inputs)

aot_actual_outputs = [
aot_inductor_module(*kjt_to_inputs(kjt)) for kjt in input_kjts[1:]
aot_inductor_module(*kjt_module_kjt_inputs(kjt))
for kjt in input_kjts[1:]
]
assert_close(expected_outputs, aot_actual_outputs)

@@ -274,7 +371,7 @@ def test_maybe_compute_kjt_to_jt_dict(self) -> None:
self._test_kjt_input_module(
ComputeKJTToJTDict(),
kjt.keys(),
(kjt._values, kjt._lengths),
(kjt._values, kjt._lengths, kjt._stride_per_key_per_rank_tensor),
# TODO: turn on AOT Inductor test once the support is ready
test_aot_inductor=False,
)
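
For context, a hedged end-to-end sketch (not committed code) of how the new test helpers fit together, assuming the definitions from the test_pt2.py diff above (make_kjt, kjt_for_tracing, kjt_module_kjt_inputs, KJTInputExportWrapper, dynamo_skipfiles_allow) are in scope: kjt_for_tracing materializes stride_per_key_per_rank as a tensor, kjt_module_kjt_inputs flattens the KJT into the (values, lengths, stride_per_key_per_rank_tensor) tuple the wrapper expects, and the wrapped module is run eagerly and under torch.compile with the same tracing inputs, as the comments on _test_kjt_input_module_dynamo_compile describe.

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

class SplitModule(torch.nn.Module):
    def forward(self, kjt: KeyedJaggedTensor):
        # Same operation as test_kjt_split: split the keys into groups of 1, 2 and 1.
        return kjt.split([1, 2, 1])

kjt = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])
wrapper = KJTInputExportWrapper(SplitModule(), kjt.keys())
# Flatten to (values, lengths, stride_per_key_per_rank_tensor), with the stride carried as a tensor.
inputs = kjt_module_kjt_inputs(kjt_for_tracing(kjt))
eager_out = wrapper(*inputs)

with dynamo_skipfiles_allow("torchrec"):
    torch._dynamo.config.capture_scalar_outputs = True
    torch._dynamo.config.capture_dynamic_output_shape_ops = True
    compiled_out = torch.compile(wrapper, backend="eager", fullgraph=True)(*inputs)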