From 4c73016ff2e8b8706080ac39b1d9fb822ce4a365 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 21 Jul 2023 19:17:44 +0000 Subject: [PATCH] [Dynamo] Enable torch._dynamo.config.suppress_errors by default (#105307) Summary: We are working toward full model compilation, where when compilation error happens, we just fall back to eager mode rather than error out. But at the same time, we should fix these issues if they are bugs. We will: * 1/ log warnings in OSS; * 2/ log warnings and write them into Scuba in fbcode; to prevent us from ignoring these issues. Test Plan: Manual test Differential Revision: D47506314 Pull Request resolved: https://github.com/pytorch/pytorch/pull/105307 Approved by: https://github.com/jansel --- test/dynamo/test_logging.py | 8 +-- test/functorch/test_eager_transforms.py | 2 + test/inductor/test_cpu_repro.py | 9 ---- torch/_dynamo/config.py | 2 +- torch/_dynamo/convert_frame.py | 17 ++++++- torch/_dynamo/eval_frame.py | 4 +- torch/_dynamo/exc.py | 63 +++++++++++++----------- torch/testing/_internal/logging_utils.py | 8 ++- 8 files changed, 64 insertions(+), 49 deletions(-) diff --git a/test/dynamo/test_logging.py b/test/dynamo/test_logging.py index 3bd02dcce3f67..e048d77d2dbfa 100644 --- a/test/dynamo/test_logging.py +++ b/test/dynamo/test_logging.py @@ -112,20 +112,20 @@ def test_dynamo_debug_default_off_artifacts(self, records): self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0) self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0) - @make_logging_test(dynamo=logging.ERROR) + @make_logging_test() def test_dynamo_error(self, records): try: fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn) fn_opt(*ARGS) except Exception: pass - self.assertEqual(len(records), 1) + self.assertEqual(len(records), 2) test_aot = within_range_record_test(2, 6, aot=logging.INFO) test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG) test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO) - @make_logging_test(dynamo=logging.ERROR) + @make_logging_test() def test_inductor_error(self, records): exitstack = contextlib.ExitStack() import torch._inductor.lowering @@ -148,7 +148,7 @@ def throw(x): fn_opt(*ARGS) except Exception: pass - self.assertEqual(len(records), 1) + self.assertEqual(len(records), 2) self.assertIsInstance(records[0].msg, str) exitstack.close() diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 217e76d45010b..d4d148c261ec4 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -4738,6 +4738,8 @@ class TestCompileTransforms(TestCase): # torch.compile is not supported on Windows # Triton only supports GPU with SM70 or later. @expectedFailureIf(IS_WINDOWS or (TEST_CUDA and not SM70OrLater)) + @torch._dynamo.config.patch(suppress_errors=False) + @skipIfTorchDynamo("Do not test torch.compile on top of torch.compile") def test_compile_vmap_hessian(self, device): # The model and inputs are a smaller version # of code at benchmark repo: diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 5908b9b8a54be..5a47b644c606a 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -1891,15 +1891,6 @@ def fn(x): self.assertTrue(same(fn(x), opt_fn(x))) assert metrics.generated_cpp_vec_kernel_count == 2 - def test_invalid_index_of_empty_tensor(self): - def fn(a): - b = a[[0]] - return b - - a = torch.tensor([]) - with self.assertRaises(RuntimeError): - torch.compile(fn)(a) - def test_ir_node_str(self): @torch.compile def fn(x: torch.Tensor) -> torch.Tensor: diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py index d39e8f4fa0cdc..3652cbbba02c2 100644 --- a/torch/_dynamo/config.py +++ b/torch/_dynamo/config.py @@ -98,7 +98,7 @@ # This is a good way to get your model to work one way or another, but you may # lose optimization opportunities this way. Devs, if your benchmark model is failing # this way, you should figure out why instead of suppressing it. -suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False)) +suppress_errors = os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", "1") == "1" # Record and write an execution record of the current frame to a file # if an exception is encountered diff --git a/torch/_dynamo/convert_frame.py b/torch/_dynamo/convert_frame.py index 92899249cfd70..ab0b4ccb0662d 100644 --- a/torch/_dynamo/convert_frame.py +++ b/torch/_dynamo/convert_frame.py @@ -33,6 +33,7 @@ augment_exc_message, BackendCompilerFailed, format_error_msg, + format_error_msg_verbose, InternalTorchDynamoError, TorchRuntimeError, unimplemented, @@ -215,7 +216,19 @@ def exception_handler(e, code, frame=None): # Only log the exception if we are going to suppress it # if aren't suppressing it, a higher level except block will handle it if config.suppress_errors: - log.error(format_error_msg(e, code, record_filename, frame)) + if config.is_fbcode(): + from torch._dynamo.fb.logging import ( # type: ignore[import] + log_dynamo_suppress_errors, + ) + + error_msg = format_error_msg_verbose(e, code, record_filename, frame) + log_dynamo_suppress_errors( + code.co_name, code.co_filename, code.co_firstlineno, error_msg + ) + else: + error_msg = format_error_msg(e, code, record_filename, frame) + + log.warning(error_msg) FRAME_COUNTER = 0 @@ -551,7 +564,7 @@ def _convert_frame( except Exception: if not config.suppress_errors: raise - log.info("converting frame raised error, suppressing error") + log.warning("converting frame raised error, suppressing error") return None _convert_frame._torchdynamo_orig_callable = compiler_fn # type: ignore[attr-defined] diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py index 0c0a340e59c02..f21039d1909b6 100644 --- a/torch/_dynamo/eval_frame.py +++ b/torch/_dynamo/eval_frame.py @@ -160,7 +160,7 @@ def remove_from_cache(f): elif hasattr(getattr(f, "forward", None), "__code__"): reset_code(f.forward.__code__) else: - from . import reset + from . import reset # type: ignore[attr-defined] reset() log.warning("could not determine __code__ for %s", f) @@ -591,7 +591,7 @@ def toy_example(a, b): @patch("torch._dynamo.symbolic_convert.explain", True) def explain(f, *args, **kwargs): # TODO(voz): Do we want a decorator for this? - from . import reset + from . import reset # type: ignore[attr-defined] reset() diff --git a/torch/_dynamo/exc.py b/torch/_dynamo/exc.py index 7d8c68e5ebdc6..bb8f38e5abd5c 100644 --- a/torch/_dynamo/exc.py +++ b/torch/_dynamo/exc.py @@ -226,39 +226,44 @@ def filter_stack(stack): return user_stack -def format_error_msg(exc, code, record_filename=None, frame=None): - msg = os.linesep * 2 +def format_error_msg_verbose(exc, code, record_filename=None, frame=None): + msg = str( + format_bytecode( + "WON'T CONVERT", + code.co_name, + code.co_filename, + code.co_firstlineno, + code, + ) + ) + msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n" + msg += format_exc() + if hasattr(exc, "real_stack"): + msg += ( + "\n" + + "=" * 10 + + " The above exception occurred while processing the following code " + + "=" * 10 + + "\n\n" + ) + stack_above_dynamo = [] + if frame is not None: + stack_above_dynamo = filter_stack(extract_stack(frame)) - if config.verbose: - msg = str( - format_bytecode( - "WON'T CONVERT", - code.co_name, - code.co_filename, - code.co_firstlineno, - code, - ) + msg += "".join( + format_list(stack_above_dynamo + list(reversed(get_real_stack(exc)))) ) - msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n" - msg += format_exc() - if hasattr(exc, "real_stack"): - msg += ( - "\n" - + "=" * 10 - + " The above exception occurred while processing the following code " - + "=" * 10 - + "\n\n" - ) - stack_above_dynamo = [] - if frame is not None: - stack_above_dynamo = filter_stack(extract_stack(frame)) + msg += "\n" + msg += "=" * 10 + + return msg - msg += "".join( - format_list(stack_above_dynamo + list(reversed(get_real_stack(exc)))) - ) - msg += "\n" - msg += "=" * 10 +def format_error_msg(exc, code, record_filename=None, frame=None): + msg = os.linesep * 2 + + if config.verbose: + msg = format_error_msg_verbose(exec, code, record_filename, frame) else: msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\ line {code.co_firstlineno} \ndue to: \n{format_exc(limit=-1)}" diff --git a/torch/testing/_internal/logging_utils.py b/torch/testing/_internal/logging_utils.py index acf8eca3e3cd5..c34efdedb71e9 100644 --- a/torch/testing/_internal/logging_utils.py +++ b/torch/testing/_internal/logging_utils.py @@ -75,8 +75,12 @@ def test_fn(self): torch._dynamo.reset() records = [] # run with env var - with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records): - fn(self, records) + if len(kwargs) == 0: + with self._handler_watcher(records): + fn(self, records) + else: + with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records): + fn(self, records) # run with API torch._dynamo.reset()