Skip to content

Commit

Permalink
Allow (temporarily?) non-fake input during ONNX export with fake mode (
Browse files Browse the repository at this point in the history
…pytorch#105246)

Although both the input and the model are expected to be fake during ONNX export with fake mode enabled, some models apparently create new parameters during tracing. That causes internal checks on the dynamo side to fail when `allow_non_fake_inputs=True` is not set on the fake mode used for `torch._dynamo.export`.

pytorch#105077 tracks this issue; once a proper fix lands, we will set `allow_non_fake_inputs=False` again.

In addition, a possible bug was found in `torch.nn.Module.state_dict()`: some registered buffers are not listed.

This is being tracked by pytorch#105233, but in the meantime we merge the `state_dict()` and `named_buffers()` results to create a full `state_dict` for the model.

Two larger, more complex tests are added for the ONNX export; they are the same ones used for the experimental symbolic tracing: tiny gpt2 and toy mlp (https://github.com/pytorch/pytorch/blob/main/test/onnx/test_fx_to_onnx_with_onnxruntime.py#L766-L825)

PS: pytorch#105464 tracks the pending tasks/limitations from this PR.
Pull Request resolved: pytorch#105246
Approved by: https://github.com/BowenBao
  • Loading branch information
Thiago Crepaldi authored and pytorchmergebot committed Jul 21, 2023
1 parent 04da0c7 commit 842616b
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 10 deletions.
13 changes: 8 additions & 5 deletions test/onnx/test_fx_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,11 +361,14 @@ def forward(self, x):
)

# Scenario 6: Real model and fake input WITH fake_context
with pytest.raises(torch.onnx.OnnxExporterError):
export_options = ExportOptions(fake_context=fake_context)
_ = torch.onnx.dynamo_export(
real_model, fake_x, export_options=export_options
)
# TODO: Delete the test below if mixed mode (real+fake), enabled by
# https://github.com/pytorch/pytorch/pull/105246, will be permanently allowed.
# Tracked by https://github.com/pytorch/pytorch/issues/105077
# with pytest.raises(torch.onnx.OnnxExporterError):
# export_options = ExportOptions(fake_context=fake_context)
# _ = torch.onnx.dynamo_export(
# real_model, fake_x, export_options=export_options
# )


if __name__ == "__main__":
Expand Down
79 changes: 76 additions & 3 deletions test/onnx/test_fx_to_onnx_with_onnxruntime.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,12 +860,19 @@ def _test_fake_tensor_mode_exporter(
) as tmp_checkpoint_file:
# Dump state_dict to a file to simulate how HuggingFace model is initialized.
# The file will be loaded via .load_state_dict(...)
torch.save(real_model.state_dict(), tmp_checkpoint_file.name)
state_dict = real_model.state_dict()
# TODO: Remove explicit named_buffers when https://github.com/pytorch/pytorch/issues/105233 is fixed
state_dict.update(dict(real_model.named_buffers()))
torch.save(state_dict, tmp_checkpoint_file.name)

with torch.onnx.enable_fake_mode() as fake_context:
fake_args = create_args()
fake_kwargs = create_kwargs()
fake_model = create_model()
# TODO: Remove strict=False when https://github.com/pytorch/pytorch/issues/105233 is fixed
fake_model.load_state_dict(
torch.load(tmp_checkpoint_file.name), strict=False
)

# Export the model with fake inputs and parameters
export_options = torch.onnx.ExportOptions(
Expand All @@ -891,7 +898,9 @@ def _test_fake_tensor_mode_exporter(
real_model(*args, **kwargs)
)
# ORT outputs.
args_not_none = export_output.adapt_torch_inputs_to_onnx(*args)
args_not_none = export_output.adapt_torch_inputs_to_onnx(
*args, **kwargs
)

ort_outputs = onnx_test_common.run_ort(
tmp_onnx_file.name,
Expand All @@ -904,7 +913,7 @@ def _test_fake_tensor_mode_exporter(
torch.testing.assert_close(ref_output, torch.tensor(ort_output))

@pytorch_test_common.skip_op_level_debug_test(
"op_level_debug_test does not support FakeTensor yet."
"https://github.com/pytorch/pytorch/issues/105490"
)
def test_fake_tensor_mode_simple(self):
def create_model() -> nn.Module:
Expand Down Expand Up @@ -932,6 +941,70 @@ def create_pytorch_only_extra_kwargs():
create_pytorch_only_extra_kwargs,
)

# Fake-tensor-mode export test for the pretrained "sshleifer/tiny-gpt2" checkpoint.
# The factories defined below are handed to self._test_fake_tensor_mode_exporter
# together with the label "tiny_gpt2".
# NOTE(review): the decorator argument is an issue URL, presumably the skip reason -- confirm.
@pytorch_test_common.skip_op_level_debug_test(
"https://github.com/pytorch/pytorch/issues/105490"
)
def test_large_scale_exporter_with_tiny_gpt2(self):
model_name = "sshleifer/tiny-gpt2"

# Model factory: loads the checkpoint named by model_name through transformers.AutoModel.
def create_model() -> nn.Module:
return transformers.AutoModel.from_pretrained(model_name)

# Positional-argument factory: tokenizes "Hello world!" and returns
# (input_ids, None, attention_mask).
# NOTE(review): the None presumably fills the model's second positional
# forward parameter so attention_mask lands in the third slot -- confirm.
def create_args():
tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
kwargs = tokenizer("Hello world!", return_tensors="pt")
input_ids = kwargs["input_ids"]
attention_mask = kwargs["attention_mask"]
return input_ids, None, attention_mask

# Keyword-argument factory: always returns {"return_dict": False}.
def create_pytorch_only_extra_kwargs():
return {"return_dict": False}

# Arguments, in order: label, model factory, args factory, kwargs factory.
self._test_fake_tensor_mode_exporter(
"tiny_gpt2",
create_model,
create_args,
create_pytorch_only_extra_kwargs,
)

# Fake-tensor-mode export test for a small locally defined MLP.
# The factories defined below are handed to self._test_fake_tensor_mode_exporter
# together with the label "toy_mlp1".
# NOTE(review): the decorator argument is an issue URL, presumably the skip reason -- confirm.
@pytorch_test_common.skip_op_level_debug_test(
"https://github.com/pytorch/pytorch/issues/105490"
)
def test_large_scale_exporter_with_toy_mlp(self):
# Four biased Linear layers chained 8 -> 8 -> 4 -> 2 -> 2.
class MLPModel(nn.Module):
def __init__(self):
super().__init__()
self.fc0 = nn.Linear(8, 8, bias=True)
self.fc1 = nn.Linear(8, 4, bias=True)
self.fc2 = nn.Linear(4, 2, bias=True)
self.fc3 = nn.Linear(2, 2, bias=True)

# Applies sigmoid after fc0, fc1 and fc2; the fc3 output is returned as-is.
def forward(self, tensor_x: torch.Tensor):
tensor_x = self.fc0(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
tensor_x = self.fc1(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
tensor_x = self.fc2(tensor_x)
tensor_x = torch.sigmoid(tensor_x)
output = self.fc3(tensor_x)
return output

# Model factory: builds a fresh MLPModel on every call.
def create_model() -> nn.Module:
return MLPModel()

# Positional-argument factory: a single random float32 tensor of shape (97, 8);
# the trailing dimension matches fc0's 8 input features.
def create_args():
return (torch.rand((97, 8), dtype=torch.float32),)

# Keyword-argument factory: this model takes no keyword arguments.
def create_kwargs():
return {}

# Arguments, in order: label, model factory, args factory, kwargs factory.
self._test_fake_tensor_mode_exporter(
"toy_mlp1",
create_model,
create_args,
create_kwargs,
)


if __name__ == "__main__":
common_utils.run_tests()
6 changes: 4 additions & 2 deletions torch/onnx/_internal/exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,13 +273,15 @@ def enable_fake_mode():
# Ideally we should keep them in sync to preserve the same default behavior
# [1] `torch/_dynamo/output_graph.py::InstructionTranslator::OutputGraph.__init__`
fake_mode = fake_tensor.FakeTensorMode(
allow_non_fake_inputs=False,
allow_non_fake_inputs=True, # https://github.com/pytorch/pytorch/issues/105077
shape_env=ShapeEnv(
allow_scalar_outputs=False, allow_dynamic_output_shape_ops=False
),
)
# The patcher is needed for when user calls `fake_model.load_state_dict(...)` within fake mode
patcher_context = patcher.ONNXTorchPatcher()
fake_context = ONNXFakeContext(fake_mode=fake_mode)
with fake_mode:
with fake_mode, patcher_context:
yield fake_context


Expand Down
2 changes: 2 additions & 0 deletions torch/onnx/_internal/fx/dynamo_graph_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,8 @@ def generate_fx(
*model_args,
tracing_mode=fx_mode,
fake_mode=fake_mode, # type: ignore[arg-type]
aten_graph=fake_mode
is not None, # TODO: Tracked by https://github.com/pytorch/pytorch/issues/105467
**model_kwargs,
)
del graph_guard # Unused
Expand Down

0 comments on commit 842616b

Please sign in to comment.