From 0553aaf524e0cf6fa417a7348b7c5ccc3a3587ab Mon Sep 17 00:00:00 2001 From: Paul Zhang Date: Fri, 23 Feb 2024 13:35:14 -0800 Subject: [PATCH] Introduce pt2 ir export unit test Summary: Introduce unit testing torch.export Reviewed By: IvanKobzarev Differential Revision: D53426083 --- torchrec/distributed/tests/test_pt2.py | 31 ++++++++++++++++++++++---- 1 file changed, 27 insertions(+), 4 deletions(-) diff --git a/torchrec/distributed/tests/test_pt2.py b/torchrec/distributed/tests/test_pt2.py index a9238c0ac..c809691eb 100644 --- a/torchrec/distributed/tests/test_pt2.py +++ b/torchrec/distributed/tests/test_pt2.py @@ -103,17 +103,20 @@ def _test_kjt_input_module( kjt_keys: List[str], # pyre-ignore inputs, + test_dynamo: bool = True, test_aot_inductor: bool = True, + test_pt2_ir_export: bool = False, ) -> None: with dynamo_skipfiles_allow("torchrec"): EM: torch.nn.Module = KJTInputExportWrapper(kjt_input_module, kjt_keys) eager_output = EM(*inputs) - x = torch._dynamo.export(EM, same_signature=True)(*inputs) + if test_dynamo: + x = torch._dynamo.export(EM, same_signature=True)(*inputs) - export_gm = x.graph_module - export_gm_output = export_gm(*inputs) + export_gm = x.graph_module + export_gm_output = export_gm(*inputs) - assert_close(eager_output, export_gm_output) + assert_close(eager_output, export_gm_output) if test_aot_inductor: # pyre-ignore @@ -127,6 +130,11 @@ def _test_kjt_input_module( aot_actual_output = aot_inductor_module(*inputs) assert_close(eager_output, aot_actual_output) + if test_pt2_ir_export: + pt2_ir = torch.export.export(EM, inputs, {}, strict=False) + pt2_ir_output = pt2_ir(*inputs) + assert_close(eager_output, pt2_ir_output) + def test_kjt_split(self) -> None: class M(torch.nn.Module): def forward(self, kjt: KeyedJaggedTensor, segments: List[int]): @@ -155,6 +163,21 @@ def forward(self, kjt: KeyedJaggedTensor, indices: List[int]): test_aot_inductor=False, ) + def test_kjt_length_per_key(self) -> None: + class M(torch.nn.Module): + def forward(self, kjt: KeyedJaggedTensor): + return kjt.length_per_key() + + kjt: KeyedJaggedTensor = make_kjt([2, 3, 4, 5, 6], [1, 2, 1, 1]) + + self._test_kjt_input_module( + M(), + kjt.keys(), + (kjt._values, kjt._lengths), + test_aot_inductor=False, + test_pt2_ir_export=True, + ) + # pyre-ignore @unittest.skipIf( torch.cuda.device_count() <= 1,