Skip to content

Commit

Permalink
Introduce pt2 ir export unit test
Browse files Browse the repository at this point in the history
Summary: Introduce unit testing torch.export

Reviewed By: IvanKobzarev

Differential Revision: D53426083
  • Loading branch information
PaulZhang12 authored and facebook-github-bot committed Feb 23, 2024
1 parent 9604533 commit 0553aaf
Showing 1 changed file with 27 additions and 4 deletions.
31 changes: 27 additions & 4 deletions torchrec/distributed/tests/test_pt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,20 @@ def _test_kjt_input_module(
kjt_keys: List[str],
# pyre-ignore
inputs,
test_dynamo: bool = True,
test_aot_inductor: bool = True,
test_pt2_ir_export: bool = False,
) -> None:
with dynamo_skipfiles_allow("torchrec"):
EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt_keys)
eager_output = EM(*inputs)
x = torch._dynamo.export(EM, same_signature=True)(*inputs)
if test_dynamo:
x = torch._dynamo.export(EM, same_signature=True)(*inputs)

export_gm = x.graph_module
export_gm_output = export_gm(*inputs)
export_gm = x.graph_module
export_gm_output = export_gm(*inputs)

assert_close(eager_output, export_gm_output)
assert_close(eager_output, export_gm_output)

if test_aot_inductor:
# pyre-ignore
Expand All @@ -127,6 +130,11 @@ def _test_kjt_input_module(
aot_actual_output = aot_inductor_module(*inputs)
assert_close(eager_output, aot_actual_output)

if test_pt2_ir_export:
pt2_ir = torch.export.export(EM, inputs, {}, strict=False)
pt2_ir_output = pt2_ir(*inputs)
assert_close(eager_output, pt2_ir_output)

def test_kjt_split(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor, segments: List[int]):
Expand Down Expand Up @@ -155,6 +163,21 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]):
test_aot_inductor=False,
)

def test_kjt_length_per_key(self) -> None:
class M(torch.nn.Module):
def forward(self, kjt: KeyedJaggedTensor):
return kjt.length_per_key()

kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1])

self._test_kjt_input_module(
M(),
kjt.keys(),
(kjt._values, kjt._lengths),
test_aot_inductor=False,
test_pt2_ir_export=True,
)

# pyre-ignore
@unittest.skipIf(
torch.cuda.device_count() <= 1,
Expand Down

0 comments on commit 0553aaf

Please sign in to comment.