diff --git a/test/onnx/test_fx_to_onnx.py b/test/onnx/test_fx_to_onnx.py
index a16a1f634885d..4c7de5180cba2 100644
--- a/test/onnx/test_fx_to_onnx.py
+++ b/test/onnx/test_fx_to_onnx.py
@@ -361,11 +361,14 @@ def forward(self, x):
         )
 
         # Scenario 6: Real model and fake input WITH fake_context
-        with pytest.raises(torch.onnx.OnnxExporterError):
-            export_options = ExportOptions(fake_context=fake_context)
-            _ = torch.onnx.dynamo_export(
-                real_model, fake_x, export_options=export_options
-            )
+        # TODO: Delete the test below once https://github.com/pytorch/pytorch/pull/105246
+        # is merged, since mixed mode (real+fake) will then be permanently allowed.
+        # Tracked by https://github.com/pytorch/pytorch/issues/105077
+        # with pytest.raises(torch.onnx.OnnxExporterError):
+        #     export_options = ExportOptions(fake_context=fake_context)
+        #     _ = torch.onnx.dynamo_export(
+        #         real_model, fake_x, export_options=export_options
+        #     )
 
 
 if __name__ == "__main__":
diff --git a/test/onnx/test_fx_to_onnx_with_onnxruntime.py b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
index 04d1de17cef3e..0d04848cdf111 100644
--- a/test/onnx/test_fx_to_onnx_with_onnxruntime.py
+++ b/test/onnx/test_fx_to_onnx_with_onnxruntime.py
@@ -860,12 +860,19 @@ def _test_fake_tensor_mode_exporter(
         ) as tmp_checkpoint_file:
             # Dump state_dict to a file to simulate how HuggingFace model is initialized.
             # The file will be loaded via .load_state_dict(...)
-            torch.save(real_model.state_dict(), tmp_checkpoint_file.name)
+            state_dict = real_model.state_dict()
+            # TODO: Remove explicit named_buffers when https://github.com/pytorch/pytorch/issues/105233 is fixed
+            state_dict.update(dict(real_model.named_buffers()))
+            torch.save(state_dict, tmp_checkpoint_file.name)
 
             with torch.onnx.enable_fake_mode() as fake_context:
                 fake_args = create_args()
                 fake_kwargs = create_kwargs()
                 fake_model = create_model()
+                # TODO: Remove strict=False when https://github.com/pytorch/pytorch/issues/105233 is fixed
+                fake_model.load_state_dict(
+                    torch.load(tmp_checkpoint_file.name), strict=False
+                )
 
             # Export the model with fake inputs and parameters
             export_options = torch.onnx.ExportOptions(
@@ -891,7 +898,9 @@ def _test_fake_tensor_mode_exporter(
                     real_model(*args, **kwargs)
                 )  # ORT outputs.
-                args_not_none = export_output.adapt_torch_inputs_to_onnx(*args)
+                args_not_none = export_output.adapt_torch_inputs_to_onnx(
+                    *args, **kwargs
+                )
 
                 ort_outputs = onnx_test_common.run_ort(
                     tmp_onnx_file.name,
@@ -904,7 +913,7 @@ def _test_fake_tensor_mode_exporter(
                     torch.testing.assert_close(ref_output, torch.tensor(ort_output))
 
     @pytorch_test_common.skip_op_level_debug_test(
-        "op_level_debug_test does not support FakeTensor yet."
+ "https://github.com/pytorch/pytorch/issues/105490" ) def test_fake_tensor_mode_simple(self): def create_model() -> nn.Module: @@ -932,6 +941,70 @@ def create_pytorch_only_extra_kwargs(): create_pytorch_only_extra_kwargs, ) + @pytorch_test_common.skip_op_level_debug_test( + "https://github.com/pytorch/pytorch/issues/105490" + ) + def test_large_scale_exporter_with_tiny_gpt2(self): + model_name = "sshleifer/tiny-gpt2" + + def create_model() -> nn.Module: + return transformers.AutoModel.from_pretrained(model_name) + + def create_args(): + tokenizer = transformers.AutoTokenizer.from_pretrained(model_name) + kwargs = tokenizer("Hello world!", return_tensors="pt") + input_ids = kwargs["input_ids"] + attention_mask = kwargs["attention_mask"] + return input_ids, None, attention_mask + + def create_pytorch_only_extra_kwargs(): + return {"return_dict": False} + + self._test_fake_tensor_mode_exporter( + "tiny_gpt2", + create_model, + create_args, + create_pytorch_only_extra_kwargs, + ) + + @pytorch_test_common.skip_op_level_debug_test( + "https://github.com/pytorch/pytorch/issues/105490" + ) + def test_large_scale_exporter_with_toy_mlp(self): + class MLPModel(nn.Module): + def __init__(self): + super().__init__() + self.fc0 = nn.Linear(8, 8, bias=True) + self.fc1 = nn.Linear(8, 4, bias=True) + self.fc2 = nn.Linear(4, 2, bias=True) + self.fc3 = nn.Linear(2, 2, bias=True) + + def forward(self, tensor_x: torch.Tensor): + tensor_x = self.fc0(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + tensor_x = self.fc1(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + tensor_x = self.fc2(tensor_x) + tensor_x = torch.sigmoid(tensor_x) + output = self.fc3(tensor_x) + return output + + def create_model() -> nn.Module: + return MLPModel() + + def create_args(): + return (torch.rand((97, 8), dtype=torch.float32),) + + def create_kwargs(): + return {} + + self._test_fake_tensor_mode_exporter( + "toy_mlp1", + create_model, + create_args, + create_kwargs, + ) + if __name__ == "__main__": common_utils.run_tests() diff --git a/torch/onnx/_internal/exporter.py b/torch/onnx/_internal/exporter.py index aae7eb6f8e380..b2f3e57174de9 100644 --- a/torch/onnx/_internal/exporter.py +++ b/torch/onnx/_internal/exporter.py @@ -273,13 +273,15 @@ def enable_fake_mode(): # Ideally we should keep them in sync to preserve the same default behavior # [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__` fake_mode = fake_tensor.FakeTensorMode( - allow_non_fake_inputs=False, + allow_non_fake_inputs=True, # https://github.com/pytorch/pytorch/issues/105077 shape_env=ShapeEnv( allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False ), ) + # The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode + patcher_context = patcher.ONNXTorchPatcher() fake_context = ONNXFakeContext(fake_mode=fake_mode) - with fake_mode: + with fake_mode, patcher_context: yield fake_context diff --git a/torch/onnx/_internal/fx/dynamo_graph_extractor.py b/torch/onnx/_internal/fx/dynamo_graph_extractor.py index 1b4a2d250eadd..5b24b88d20907 100644 --- a/torch/onnx/_internal/fx/dynamo_graph_extractor.py +++ b/torch/onnx/_internal/fx/dynamo_graph_extractor.py @@ -194,6 +194,8 @@ def generate_fx( *model_args, tracing_mode=fx_mode, fake_mode=fake_mode, # type: ignore[arg-type] + aten_graph=fake_mode + is not None, # TODO: Tracked by https://github.com/pytorch/pytorch/issues/105467 **model_kwargs, ) del graph_guard # Unused