From 842616bcba64e001ccc67b8f911939cc465364e9 Mon Sep 17 00:00:00 2001 From: Thiago Crepaldi Date: Thu, 20 Jul 2023 20:09:40 +0000 Subject: [PATCH] Allow (temporarily?) non-fake input during ONNX export with fake mode (#105246) Although input and model are expected to be fake during ONNX export with fake mode enabled, apparently some models can create new parameters during tracing. That makes internal checks on dynamo side to fail when we dont set `allow_non_fake_input=True` for `torch._dynamo.export`. https://github.com/pytorch/pytorch/issues/105077 tracks this issue and if a proper fix is done, we will set `allow_non_fake_input=False` again Additionally to that, a possible bug was found at torch.nn.Module.state_dict() in which some registered buffers are not listed. This is being tracked by https://github.com/pytorch/pytorch/issues/105233 but in the mean time, we are merging `state_dict()` and `named_buffers()` results to create a full `state_dict` for the model Two more complex/larger tests are added to the ONNX export which are the same for the experimental symbolic tracing: tiny gpt2 and toy mlp (https://github.com/pytorch/pytorch/blob/main/test/onnx/test_fx_to_onnx_with_onnxruntime.py#L766-L825) ps: https://github.com/pytorch/pytorch/issues/105464 tracks pending tasks/limitations from this PR Pull Request resolved: https://github.com/pytorch/pytorch/pull/105246 Approved by: https://github.com/BowenBao --- test/onnx/test_fx_to_onnx.py | 13 +-- test/onnx/test_fx_to_onnx_with_onnxruntime.py | 79 ++++++++++++++++++- torch/onnx/_internal/exporter.py | 6 +- .../_internal/fx/dynamo_graph_extractor.py | 2 + 4 files changed, 90 insertions(+), 10 deletions(-) diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py index a16a1f634885d..4c7de5180cba2 100644 --- a/test/onnx/test_fx_to_onnx.py +++ b/test/onnx/test_fx_to_onnx.py @@ -361,11 +361,14 @@ def forward(self, x): ) # Scenario 6: Real model and fake input WITH fake_context - with pytest.raises(torch.onnx.OnnxExporterError): - export_options = ExportOptions(fake_context=fake_context) - _ = torch.onnx.dynamo_export( - real_model, fake_x, export_options=export_options - ) + # TODO: Delete the test below if https://github.com/pytorch/pytorch/pull/105246 + # Tracked by https://github.com/pytorch/pytorch/issues/105077 + # mixed mode (real+fake) will be permanently allowed. + # with pytest.raises(torch.onnx.OnnxExporterError): + # export_options = ExportOptions(fake_context=fake_context) + # _ = torch.onnx.dynamo_export( + # real_model, fake_x, export_options=export_options + # ) if __name__ == "__main__": diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py index 04d1de17cef3e..0d04848cdf111 100644 --- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py +++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py @@ -860,12 +860,19 @@ def _test_fake_tensor_mode_exporter( ) as tmp_checkpoint_file: # Dump state_dict to a file to simulate how HuggingFace model is initialized. # The file will be loaded via .load_state_dict(...) - torch.save(real_model.state_dict(), tmp_checkpoint_file.name) + state_dict = real_model.state_dict() + # TODO: Remove explicit named_bufefrs when # https://github.com/pytorch/pytorch/issues/105233 is fixed + state_dict.update(dict(real_model.named_buffers())) + torch.save(state_dict, tmp_checkpoint_file.name) with torch.onnx.enable_fake_mode() as fake_context: fake_args = create_args() fake_kwargs = create_kwargs() fake_model = create_model() + # TODO: Remove strict=False when https://github.com/pytorch/pytorch/issues/105233 is fixed + fake_model.load_state_dict( + torch.load(tmp_checkpoint_file.name), strict=False + ) # Export the model with fake inputs and parameters export_options = torch.onnx.ExportOptions( @@ -891,7 +898,9 @@ def _test_fake_tensor_mode_exporter( real_model(*args, **kwargs) ) # ORT outputs. - args_not_none = export_output.adapt_torch_inputs_to_onnx(*args) + args_not_none = export_output.adapt_torch_inputs_to_onnx( + *args, **kwargs + ) ort_outputs = onnx_test_common.run_ort( tmp_onnx_file.name, @@ -904,7 +913,7 @@ def _test_fake_tensor_mode_exporter( torch.testing.assert_close(ref_output, torch.tensor(ort_output)) @pytorch_test_common.skip_op_level_debug_test( - "op_level_debug_test does not support FakeTensor yet." + "https://github.com/pytorch/pytorch/issues/105490" ) def test_fake_tensor_mode_simple(self): def create_model() -> nn.Module: @@ -932,6 +941,70 @@ def create_pytorch_only_extra_kwargs(): create_pytorch_only_extra_kwargs, ) + @pytorch_test_common.skip_op_level_debug_test( + "https://github.com/pytorch/pytorch/issues/105490" + ) + def test_large_scale_exporter_with_tiny_gpt2(self): + model_name = "sshleifer/tiny-gpt2" + + def create_model() -> nn.Module: + return transformers.AutoModel.from_pretrained(model_name) + + def create_args(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + kwargs = tokenizer("Hello world!", return_tensors="pt") + input_ids = kwargs["input_ids"] + attention_mask = kwargs["attention_mask"] + return input_ids, None, attention_mask + + def create_pytorch_only_extra_kwargs(): + return {"return_dict": False} + + self._test_fake_tensor_mode_exporter( + "tiny_gpt2", + create_model, + create_args, + create_pytorch_only_extra_kwargs, + ) + + @pytorch_test_common.skip_op_level_debug_test( + "https://github.com/pytorch/pytorch/issues/105490" + ) + def test_large_scale_exporter_with_toy_mlp(self): + class MLPModel(nn.Module): + def __init__(self): + super().__init__() + self.fc0 = nn.Linear(8, 8, bias=True) + self.fc1 = nn.Linear(8, 4, bias=True) + self.fc2 = nn.Linear(4, 2, bias=True) + self.fc3 = nn.Linear(2, 2, bias=True) + + def forward(self, tensor_x: torch.Tensor): + tensor_x = self.fc0(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + tensor_x = self.fc1(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + tensor_x = self.fc2(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + output = self.fc3(tensor_x) + return output + + def create_model() -> nn.Module: + return MLPModel() + + def create_args(): + return (torch.rand((97, 8), dtype=torch.float32),) + + def create_kwargs(): + return {} + + self._test_fake_tensor_mode_exporter( + "toy_mlp1", + create_model, + create_args, + create_kwargs, + ) + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index aae7eb6f8e380..b2f3e57174de9 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -273,13 +273,15 @@ def enable_fake_mode(): # Ideally we should keep them in sync to preserve the same default behavior # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__` fake_mode = fake_tensor.FakeTensorMode( - allow_non_fake_inputs=False, + allow_non_fake_inputs=True, # https://github.com/pytorch/pytorch/issues/105077 shape_env=ShapeEnv( allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False ), ) + # The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode + patcher_context = patcher.ONNXTorchPatcher() fake_context = ONNXFakeContext(fake_mode=fake_mode) - with fake_mode: + with fake_mode, patcher_context: yield fake_context diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index 1b4a2d250eadd..5b24b88d20907 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -194,6 +194,8 @@ def generate_fx( *model_args, tracing_mode=fx_mode, fake_mode=fake_mode, # type: ignore[arg-type] + aten_graph=fake_mode + is not None, # TODO: Tracked by https://github.com/pytorch/pytorch/issues/105467 **model_kwargs, ) del graph_guard # Unused