Remove type promotion workaround (pytorch#107202)
Removes the old type promotion workaround from the foreach lowering; inputs with differing dtypes no longer force a fallback away from the fused foreach kernel.

Pull Request resolved: pytorch#107202
Approved by: https://github.com/xuzhao9, https://github.com/eellison
mlazos authored and pytorchmergebot committed Aug 15, 2023
1 parent c9c9076 commit 2d841bc
Showing 2 changed files with 2 additions and 18 deletions.
2 changes: 1 addition & 1 deletion test/inductor/test_foreach.py
@@ -223,7 +223,7 @@ def fn(a0, a1, b0, b1):
         actual = fn_opt(*inputs)
         expected = fn(*inputs)
         self.assertEqual(actual, expected)
-        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 2)
+        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 1)

     @requires_cuda()
     @bin_ops
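For context, here is a minimal sketch (not the test code itself) of the behavior the updated assertion expects: with the workaround gone, a foreach op whose inputs need type promotion should compile to a single generated kernel instead of two. The helper fn and the tensor lists below are hypothetical, and the snippet assumes a CUDA device and that torch._inductor.metrics.reset() is available, as used elsewhere in the inductor tests.

import torch
import torch._inductor.metrics as metrics

def fn(xs, ys):
    # pairwise add over two tensor lists; float16 + float32 pairs require type promotion
    return torch._foreach_add(xs, ys)

xs = [torch.rand(8, device="cuda", dtype=torch.float16) for _ in range(2)]
ys = [torch.rand(8, device="cuda", dtype=torch.float32) for _ in range(2)]

metrics.reset()
torch.compile(fn)(xs, ys)
# After this change, mixed-dtype pairs no longer disable the foreach lowering,
# so a single fused kernel is expected.
print(metrics.generated_kernel_count)  # expected: 1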
18 changes: 1 addition & 17 deletions torch/_inductor/lowering.py
@@ -423,28 +423,12 @@ def is_dynamic(*args):
             for t in args
         )

-    def has_type_promotion(*args):
-        if len(args) < 2:
-            return False
-        else:
-            dtype = None
-            for t in args:
-                if isinstance(t, TensorBox):
-                    if dtype is None:
-                        dtype = t.data.get_dtype()  # type: ignore[attr-defined]
-                    elif dtype != t.data.get_dtype():
-                        return True
-            return False
-
     # group by device, whether any of the inputs are dynamic, and whether their types match
     # (proxy for type promotion)
-    # Note: we'll fallback on type promotion until
-    # https://github.com/openai/triton/commit/9820899b3845e461d9031dba66062efade65d420
-    # is in the pytorch triton version
     def group_args(arg_pairs):
         out = defaultdict(list)
         for i, args in enumerate(arg_pairs):
-            use_foreach = not (is_dynamic(*args) or has_type_promotion(*args))
+            use_foreach = not is_dynamic(*args)
             device = None
             for t in args:
                 if isinstance(t, TensorBox):
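To make the simplified grouping concrete, here is a standalone sketch that mimics group_args with stand-in objects rather than inductor's TensorBox; the FakeTensor class and the example pairs are invented for illustration. The point is that arg pairs are now bucketed only by (device, use_foreach), and use_foreach is cleared solely for dynamic shapes, so a mixed-dtype pair lands in the same bucket as a matching-dtype pair and can be fused into one foreach kernel.

from collections import defaultdict
from dataclasses import dataclass

@dataclass
class FakeTensor:  # stand-in for inductor's TensorBox
    device: str
    dtype: str
    dynamic: bool = False

def is_dynamic(*args):
    # mirrors the check that remains in lowering.py: any input with symbolic sizes
    return any(isinstance(t, FakeTensor) and t.dynamic for t in args)

def group_args(arg_pairs):
    out = defaultdict(list)
    for i, args in enumerate(arg_pairs):
        use_foreach = not is_dynamic(*args)  # dtype mismatches no longer matter
        device = next(t.device for t in args if isinstance(t, FakeTensor))
        out[(device, use_foreach)].append((i, args))
    return out

pairs = [
    (FakeTensor("cuda:0", "float16"), FakeTensor("cuda:0", "float32")),  # needs promotion
    (FakeTensor("cuda:0", "float32"), FakeTensor("cuda:0", "float32")),
]
# Both pairs land in the ("cuda:0", True) bucket and can be fused together;
# before this commit the mixed-dtype pair would have been split out.
print(group_args(pairs))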
