From 2d841bcb9f567f7acb9e400fd1d7d9712df343e6 Mon Sep 17 00:00:00 2001
From: Michael Lazos
Date: Tue, 15 Aug 2023 05:32:39 +0000
Subject: [PATCH] Remove type promotion workaround (#107202)

Removes old type promotion workaround

Pull Request resolved: https://github.com/pytorch/pytorch/pull/107202
Approved by: https://github.com/xuzhao9, https://github.com/eellison
---
 test/inductor/test_foreach.py |  2 +-
 torch/_inductor/lowering.py   | 18 +-----------------
 2 files changed, 2 insertions(+), 18 deletions(-)

diff --git a/test/inductor/test_foreach.py b/test/inductor/test_foreach.py
index 8c8bdc40c23e53..17301989bc0387 100644
--- a/test/inductor/test_foreach.py
+++ b/test/inductor/test_foreach.py
@@ -223,7 +223,7 @@ def fn(a0, a1, b0, b1):
         actual = fn_opt(*inputs)
         expected = fn(*inputs)
         self.assertEqual(actual, expected)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
+        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

     @requires_cuda()
     @bin_ops
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 9218f5280e2dc0..932b3576dc3c3e 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -423,28 +423,12 @@ def is_dynamic(*args):
                 for t in args
             )

-        def has_type_promotion(*args):
-            if len(args) < 2:
-                return False
-            else:
-                dtype = None
-                for t in args:
-                    if isinstance(t, TensorBox):
-                        if dtype is None:
-                            dtype = t.data.get_dtype()  # type: ignore[attr-defined]
-                        elif dtype != t.data.get_dtype():
-                            return True
-                return False
-
         # group by device, whether any of the inputs are dynamic, and whether their types match
         # (proxy for type promotion)
-        # Note: we'll fallback on type promotion until
-        # https://github.com/openai/triton/commit/9820899b3845e461d9031dba66062efade65d420
-        # is in the pytorch triton version
         def group_args(arg_pairs):
             out = defaultdict(list)
             for i, args in enumerate(arg_pairs):
-                use_foreach = not (is_dynamic(*args) or has_type_promotion(*args))
+                use_foreach = not is_dynamic(*args)
                 device = None
                 for t in args:
                     if isinstance(t, TensorBox):
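
For context, a minimal sketch (not part of the patch) of the kind of input the removed
has_type_promotion() check used to pull out of the foreach grouping: a foreach add where one
tensor pair mixes dtypes and therefore needs type promotion. The function name fn, the shapes,
dtypes, and the CUDA device below are illustrative assumptions; the actual coverage is the
updated check in test/inductor/test_foreach.py shown above.

    # Sketch only: mixed-dtype foreach add compiled with Inductor.
    import torch

    def fn(xs, ys):
        return torch._foreach_add(xs, ys)

    fn_opt = torch.compile(fn)

    # One float16/float32 pair requires type promotion within the group.
    xs = [torch.ones(8, device="cuda", dtype=torch.float16),
          torch.ones(8, device="cuda", dtype=torch.float32)]
    ys = [torch.ones(8, device="cuda", dtype=torch.float32),
          torch.ones(8, device="cuda", dtype=torch.float32)]

    out = fn_opt(xs, ys)
    # With the workaround removed, such inputs stay in the foreach group, which is
    # why the test now expects generated_kernel_count == 1 instead of 2.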