
Commit 4c73016

[Dynamo] Enable torch._dynamo.config.suppress_errors by default (pytorch#105307)

Summary:
We are working toward full model compilation, where a compilation error makes us fall back to eager mode rather than error out. At the same time, we should still fix these issues when they are genuine bugs. To prevent them from being silently ignored, we will:
* log warnings in OSS;
* log warnings and also write them to Scuba in fbcode.
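
As a minimal sketch of the intended behavior (the failing backend below is hypothetical, used only to force a compilation error):

import torch
import torch._dynamo

def failing_backend(gm, example_inputs):
    # Hypothetical backend that always fails, to trigger the fallback path.
    raise RuntimeError("simulated backend failure")

@torch._dynamo.optimize(failing_backend)
def fn(x):
    return x + 1

# With suppress_errors=True (now the default), the compilation error is
# logged as a warning and fn falls back to eager mode instead of raising.
print(fn(torch.ones(3)))  # tensor([2., 2., 2.])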

Test Plan: Manual test

Differential Revision: D47506314

Pull Request resolved: pytorch#105307
Approved by: https://github.com/jansel
yanboliang authored and pytorchmergebot committed Jul 21, 2023
1 parent de8bd10 commit 4c73016
Showing 8 changed files with 64 additions and 49 deletions.
8 changes: 4 additions & 4 deletions test/dynamo/test_logging.py
@@ -112,20 +112,20 @@ def test_dynamo_debug_default_off_artifacts(self, records):
         self.assertEqual(len([r for r in records if ".__bytecode" in r.name]), 0)
         self.assertEqual(len([r for r in records if ".__output_code" in r.name]), 0)
 
-    @make_logging_test(dynamo=logging.ERROR)
+    @make_logging_test()
     def test_dynamo_error(self, records):
         try:
             fn_opt = torch._dynamo.optimize("inductor")(dynamo_error_fn)
             fn_opt(*ARGS)
         except Exception:
             pass
-        self.assertEqual(len(records), 1)
+        self.assertEqual(len(records), 2)
 
     test_aot = within_range_record_test(2, 6, aot=logging.INFO)
     test_inductor_debug = within_range_record_test(3, 15, inductor=logging.DEBUG)
     test_inductor_info = within_range_record_test(2, 4, inductor=logging.INFO)
 
-    @make_logging_test(dynamo=logging.ERROR)
+    @make_logging_test()
     def test_inductor_error(self, records):
         exitstack = contextlib.ExitStack()
         import torch._inductor.lowering
@@ -148,7 +148,7 @@ def throw(x):
             fn_opt(*ARGS)
         except Exception:
             pass
-        self.assertEqual(len(records), 1)
+        self.assertEqual(len(records), 2)
         self.assertIsInstance(records[0].msg, str)
 
         exitstack.close()

2 changes: 2 additions & 0 deletions test/functorch/test_eager_transforms.py
@@ -4738,6 +4738,8 @@ class TestCompileTransforms(TestCase):
     # torch.compile is not supported on Windows
     # Triton only supports GPU with SM70 or later.
     @expectedFailureIf(IS_WINDOWS or (TEST_CUDA and not SM70OrLater))
+    @torch._dynamo.config.patch(suppress_errors=False)
+    @skipIfTorchDynamo("Do not test torch.compile on top of torch.compile")
    def test_compile_vmap_hessian(self, device):
         # The model and inputs are a smaller version
         # of code at benchmark repo:

9 changes: 0 additions & 9 deletions test/inductor/test_cpu_repro.py
@@ -1891,15 +1891,6 @@ def fn(x):
         self.assertTrue(same(fn(x), opt_fn(x)))
         assert metrics.generated_cpp_vec_kernel_count == 2
 
-    def test_invalid_index_of_empty_tensor(self):
-        def fn(a):
-            b = a[[0]]
-            return b
-
-        a = torch.tensor([])
-        with self.assertRaises(RuntimeError):
-            torch.compile(fn)(a)
-
     def test_ir_node_str(self):
         @torch.compile
         def fn(x: torch.Tensor) -> torch.Tensor:

2 changes: 1 addition & 1 deletion torch/_dynamo/config.py
@@ -98,7 +98,7 @@
 # This is a good way to get your model to work one way or another, but you may
 # lose optimization opportunities this way. Devs, if your benchmark model is failing
 # this way, you should figure out why instead of suppressing it.
-suppress_errors = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))
+suppress_errors = os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", "1") == "1"
 
 # Record and write an execution record of the current frame to a file
 # if an exception is encountered
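
The practical effect of the new parsing, as a sketch (plain Python semantics, no torch needed):

import os

# Old: any non-empty string was truthy, so TORCHDYNAMO_SUPPRESS_ERRORS=0
# still *enabled* suppression; unset meant False.
old = bool(os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", False))

# New: unset defaults to "1" (suppression on), and only the exact string
# "1" enables it, so TORCHDYNAMO_SUPPRESS_ERRORS=0 now disables it.
new = os.environ.get("TORCHDYNAMO_SUPPRESS_ERRORS", "1") == "1"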
17 changes: 15 additions & 2 deletions torch/_dynamo/convert_frame.py
@@ -33,6 +33,7 @@
     augment_exc_message,
     BackendCompilerFailed,
     format_error_msg,
+    format_error_msg_verbose,
     InternalTorchDynamoError,
     TorchRuntimeError,
     unimplemented,
@@ -215,7 +216,19 @@ def exception_handler(e, code, frame=None):
     # Only log the exception if we are going to suppress it
     # if aren't suppressing it, a higher level except block will handle it
     if config.suppress_errors:
-        log.error(format_error_msg(e, code, record_filename, frame))
+        if config.is_fbcode():
+            from torch._dynamo.fb.logging import (  # type: ignore[import]
+                log_dynamo_suppress_errors,
+            )
+
+            error_msg = format_error_msg_verbose(e, code, record_filename, frame)
+            log_dynamo_suppress_errors(
+                code.co_name, code.co_filename, code.co_firstlineno, error_msg
+            )
+        else:
+            error_msg = format_error_msg(e, code, record_filename, frame)
+
+        log.warning(error_msg)
 
 
 FRAME_COUNTER = 0
@@ -551,7 +564,7 @@ def _convert_frame(
         except Exception:
             if not config.suppress_errors:
                 raise
-            log.info("converting frame raised error, suppressing error")
+            log.warning("converting frame raised error, suppressing error")
             return None
 
     _convert_frame._torchdynamo_orig_callable = compiler_fn  # type: ignore[attr-defined]

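Callers that still want a hard failure can opt out, as test_eager_transforms.py does above; a sketch, reusing the hypothetical failing backend from the summary:

import torch
import torch._dynamo
from torch._dynamo.exc import BackendCompilerFailed

def failing_backend(gm, example_inputs):
    raise RuntimeError("simulated backend failure")

fn = torch._dynamo.optimize(failing_backend)(lambda x: x * 2)

with torch._dynamo.config.patch(suppress_errors=False):
    try:
        fn(torch.ones(2))
    except BackendCompilerFailed:
        print("raised instead of falling back to eager")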
4 changes: 2 additions & 2 deletions torch/_dynamo/eval_frame.py
@@ -160,7 +160,7 @@ def remove_from_cache(f):
     elif hasattr(getattr(f, "forward", None), "__code__"):
         reset_code(f.forward.__code__)
     else:
-        from . import reset
+        from . import reset  # type: ignore[attr-defined]
 
         reset()
         log.warning("could not determine __code__ for %s", f)
@@ -591,7 +591,7 @@ def toy_example(a, b):
 @patch("torch._dynamo.symbolic_convert.explain", True)
 def explain(f, *args, **kwargs):
     # TODO(voz): Do we want a decorator for this?
-    from . import reset
+    from . import reset  # type: ignore[attr-defined]
 
     reset()
 

63 changes: 34 additions & 29 deletions torch/_dynamo/exc.py
@@ -226,39 +226,44 @@ def filter_stack(stack):
     return user_stack
 
 
+def format_error_msg_verbose(exc, code, record_filename=None, frame=None):
+    msg = str(
+        format_bytecode(
+            "WON'T CONVERT",
+            code.co_name,
+            code.co_filename,
+            code.co_firstlineno,
+            code,
+        )
+    )
+    msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
+    msg += format_exc()
+    if hasattr(exc, "real_stack"):
+        msg += (
+            "\n"
+            + "=" * 10
+            + " The above exception occurred while processing the following code "
+            + "=" * 10
+            + "\n\n"
+        )
+        stack_above_dynamo = []
+        if frame is not None:
+            stack_above_dynamo = filter_stack(extract_stack(frame))
+
+        msg += "".join(
+            format_list(stack_above_dynamo + list(reversed(get_real_stack(exc))))
+        )
+        msg += "\n"
+        msg += "=" * 10
+
+    return msg
+
+
 def format_error_msg(exc, code, record_filename=None, frame=None):
     msg = os.linesep * 2
 
     if config.verbose:
-        msg = str(
-            format_bytecode(
-                "WON'T CONVERT",
-                code.co_name,
-                code.co_filename,
-                code.co_firstlineno,
-                code,
-            )
-        )
-        msg += "=" * 10 + " TorchDynamo Stack Trace " + "=" * 10 + "\n"
-        msg += format_exc()
-        if hasattr(exc, "real_stack"):
-            msg += (
-                "\n"
-                + "=" * 10
-                + " The above exception occurred while processing the following code "
-                + "=" * 10
-                + "\n\n"
-            )
-            stack_above_dynamo = []
-            if frame is not None:
-                stack_above_dynamo = filter_stack(extract_stack(frame))
-
-            msg += "".join(
-                format_list(stack_above_dynamo + list(reversed(get_real_stack(exc))))
-            )
-            msg += "\n"
-            msg += "=" * 10
+        msg = format_error_msg_verbose(exc, code, record_filename, frame)
     else:
         msg = f"WON'T CONVERT {code.co_name} {code.co_filename}\
  line {code.co_firstlineno} \ndue to: \n{format_exc(limit=-1)}"

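To see what the two formatters produce, a rough illustration (these are internal helpers, called directly here only for demonstration):

import torch._dynamo.config as config
from torch._dynamo.exc import format_error_msg

def some_fn():
    pass

try:
    raise RuntimeError("example failure")
except RuntimeError as e:
    config.verbose = False
    print(format_error_msg(e, some_fn.__code__))  # terse one-liner plus immediate cause
    config.verbose = True
    print(format_error_msg(e, some_fn.__code__))  # delegates to format_error_msg_verbose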
8 changes: 6 additions & 2 deletions torch/testing/_internal/logging_utils.py
@@ -75,8 +75,12 @@ def test_fn(self):
         torch._dynamo.reset()
         records = []
         # run with env var
-        with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
-            fn(self, records)
+        if len(kwargs) == 0:
+            with self._handler_watcher(records):
+                fn(self, records)
+        else:
+            with log_settings(kwargs_to_settings(**kwargs)), self._handler_watcher(records):
+                fn(self, records)
 
         # run with API
         torch._dynamo.reset()

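Note: this harness change is what allows the updated tests in test_logging.py to use make_logging_test() with no arguments; with nothing to override, the test runs under torch's default log settings, which is where the new suppress_errors warnings are expected to surface.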
