diff --git a/test/inductor/test_cpu_repro.py b/test/inductor/test_cpu_repro.py index 6ce94b10b45eab..f2366781b66890 100644 --- a/test/inductor/test_cpu_repro.py +++ b/test/inductor/test_cpu_repro.py @@ -4,6 +4,7 @@ import functools import itertools import math +import os import platform import sys import unittest @@ -60,12 +61,16 @@ check_model = test_torchinductor.check_model requires_vectorization = unittest.skipUnless( - cpu_vec_isa.valid_vec_isa_list(), "Does not support vectorization" + cpu_vec_isa.valid_vec_isa_list() and os.getenv("ATEN_CPU_CAPABILITY") != "default", + "Does not support vectorization", ) def check_metrics_vec_kernel_count(num_expected_vec_kernels): - if cpu_vec_isa.valid_vec_isa_list(): + if ( + cpu_vec_isa.valid_vec_isa_list() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ): assert metrics.generated_cpp_vec_kernel_count == num_expected_vec_kernels @@ -1586,6 +1591,78 @@ def fn(x): metrics.reset() self.common(fn, (value,)) + @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") + @unittest.skipIf( + not cpu_vec_isa.valid_vec_isa_list() + or "avx2" in [str(vec_isa) for vec_isa in cpu_vec_isa.valid_vec_isa_list()], + "Does not support vectorization or not s390x/aarch64/ppc64le machine", + ) + @patch("torch.cuda.is_available", lambda: False) + def test_auto_zvec_neon_vsx_simd(self): + vec_zvec_neon_vsx = cpu_vec_isa.valid_vec_isa_list()[0] + self.assertTrue(vec_zvec_neon_vsx.bit_width() == 256) + + with config.patch({"cpp.simdlen": 0}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 1}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 257}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": 256}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon_vsx) + + pre_var = os.getenv("ATEN_CPU_CAPABILITY") + if pre_var: + os.environ.pop("ATEN_CPU_CAPABILITY") + + try: + with 
config.patch({"cpp.simdlen": None}): + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx2" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx512" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "default" + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "neon" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "zvector" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon_vsx) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "vsx" + isa = cpu_vec_isa.pick_vec_isa() + self.assertTrue(isa == vec_zvec_neon_vsx) + + finally: + if pre_var: + os.environ["ATEN_CPU_CAPABILITY"] = pre_var + elif os.getenv("ATEN_CPU_CAPABILITY"): + os.environ.pop("ATEN_CPU_CAPABILITY") + @unittest.skipIf(IS_FBCODE, "Not yet runnable in fbcode") @unittest.skipIf( platform.machine() != "x86_64" or not cpu_vec_isa.valid_vec_isa_list(), @@ -1606,15 +1683,6 @@ def test_auto_simd(self): self.assertTrue(vec_avx512.nelements(torch.bfloat16) == 32) self.assertTrue(vec_avx2.nelements(torch.bfloat16) == 16) - with config.patch({"cpp.simdlen": None}): - isa = cpu_vec_isa.pick_vec_isa() - if vec_amx in cpu_vec_isa.valid_vec_isa_list(): - self.assertTrue(isa == vec_amx) - elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): - self.assertTrue(isa == vec_avx512) - else: - self.assertTrue(isa == vec_avx2) - with config.patch({"cpp.simdlen": 0}): isa = cpu_vec_isa.pick_vec_isa() self.assertFalse(isa) @@ -1646,6 
+1714,81 @@ def test_auto_simd(self): isa = cpu_vec_isa.pick_vec_isa() self.assertTrue(isa == vec_avx2) + pre_var = os.getenv("ATEN_CPU_CAPABILITY") + if pre_var: + os.environ.pop("ATEN_CPU_CAPABILITY") + + try: + with config.patch({"cpp.simdlen": None}): + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx2" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx2) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx2) + elif vec_avx2 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "avx512" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "default" + isa = cpu_vec_isa.pick_vec_isa() + self.assertFalse(isa) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "neon" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "zvector" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in 
cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + with config.patch({"cpp.simdlen": None}): + os.environ["ATEN_CPU_CAPABILITY"] = "vsx" + isa = cpu_vec_isa.pick_vec_isa() + if vec_amx in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_amx) + elif vec_avx512 in cpu_vec_isa.valid_vec_isa_list(): + self.assertTrue(isa == vec_avx512) + else: + self.assertTrue(isa == vec_avx2) + + finally: + if pre_var: + os.environ["ATEN_CPU_CAPABILITY"] = pre_var + elif os.getenv("ATEN_CPU_CAPABILITY"): + os.environ.pop("ATEN_CPU_CAPABILITY") + @requires_vectorization @patch("torch.cuda.is_available", lambda: False) def test_masked_fill_softmax(self): @@ -2626,6 +2769,7 @@ def fn(x, y): 1, ) + @requires_vectorization def test_argmin(self): def fn(x): return torch.argmin(x, -1) @@ -2637,6 +2781,7 @@ def fn(x): self.common(fn, (x,)) assert metrics.generated_cpp_vec_kernel_count == 1 + @requires_vectorization def test_argmax_argmin_with_nan_value(self): def fn(x): return torch.argmax(x) @@ -3521,6 +3666,7 @@ def forward(self, idx, x): self.common(m, (idx, x)) check_metrics_vec_kernel_count(1) + @requires_vectorization def test_embedding_vec_bf16(self): class M(torch.nn.Module): def __init__(self) -> None: @@ -3862,7 +4008,7 @@ def fn(x): x = torch.randint(0, 100, (819,), dtype=torch.int64) metrics.reset() self.common(fn, (x,)) - assert metrics.generated_cpp_vec_kernel_count == 1 + check_metrics_vec_kernel_count(1) def test_highp_to_lowp_cse_var_cache_with_store(self): # Fix issue: https://github.com/pytorch/pytorch/issues/128263 @@ -3896,7 +4042,7 @@ def fn(x): x = torch.randint(0, 100, (22, 51), dtype=torch.int64) metrics.reset() self.common(fn, (x,)) - assert metrics.generated_cpp_vec_kernel_count == 1 + check_metrics_vec_kernel_count(1) @config.patch({"cpp.dynamic_threads": True}) def test_reduction_with_dynamic_threads(self): @@ -4007,6 +4153,7 @@ def fn(arg0_1, arg0_2): exactly=True, ).run(code) + 
@requires_vectorization def test_repeated_exp(self): def fn(x): y = x.sigmoid() @@ -4035,6 +4182,7 @@ def fn(x): self.common(fn, (x,)) check_metrics_vec_kernel_count(1) + @requires_vectorization def test_consistent_remove_buffers(self): def fn(x): z = x + x diff --git a/test/inductor/test_extension_backend.py b/test/inductor/test_extension_backend.py index 6f972e46a1d987..c550c8e3b13d89 100644 --- a/test/inductor/test_extension_backend.py +++ b/test/inductor/test_extension_backend.py @@ -148,7 +148,10 @@ def fn(a, b, c): metrics.reset() opt_fn = torch.compile()(fn) _, code = run_and_get_cpp_code(opt_fn, x, y, z) - if cpu_vec_isa.valid_vec_isa_list(): + if ( + cpu_vec_isa.valid_vec_isa_list() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ): load_expr = "loadu" else: load_expr = " = in_ptr0[static_cast(i0)];" diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py index e74bce595b9b16..4ce56c4ce3d676 100644 --- a/test/inductor/test_torchinductor.py +++ b/test/inductor/test_torchinductor.py @@ -99,7 +99,7 @@ importlib.import_module("functorch") importlib.import_module("filelock") -from torch._inductor import config, test_operators +from torch._inductor import config, cpu_vec_isa, test_operators from torch._inductor.compile_fx import ( compile_fx, compile_fx_inner, @@ -1523,10 +1523,16 @@ def test( pass # no device asserts in halide elif self.device == "cpu": _, code = run_and_get_cpp_code(fn_opt, *inps) - self.assertTrue((") ? (" in code or "blendv" in code) is has_wrapping) self.assertTrue(("TORCH_CHECK" in code) is has_assert) - # Assert that we always vectorize the kernel regardless of wrapping / checks - self.assertTrue(("loadu" in code) is vectorize) + if ( + cpu_vec_isa.valid_vec_isa_list() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" + ): + self.assertTrue( + (") ? 
(" in code or "blendv" in code) is has_wrapping + ) + # Assert that we always vectorize the kernel regardless of wrapping / checks + self.assertTrue(("loadu" in code) is vectorize) else: code = run_and_get_triton_code(fn_opt, *inps) self.assertTrue(("tl.where" in code) is has_wrapping) @@ -1838,8 +1844,20 @@ def test_multilayer_var_lowp(self): def fn(a): return torch.var(a) - self.common(fn, (torch.rand((16, 16, 352, 352), dtype=torch.float16),)) - self.common(fn, (torch.rand((14923), dtype=torch.float16),)) + atol = None + rtol = None + if self.device == "cpu" and os.getenv("ATEN_CPU_CAPABILITY") == "default": + atol = 1e-3 + rtol = 1e-3 + self.common( + fn, + (torch.rand((16, 16, 352, 352), dtype=torch.float16),), + atol=atol, + rtol=rtol, + ) + self.common( + fn, (torch.rand((14923), dtype=torch.float16),), atol=atol, rtol=rtol + ) def test_split_cumsum(self): def fn(a): @@ -10103,9 +10121,15 @@ def fn(query, scores, window_overlap): if is_cpp_backend(self.device): opt_fn = torch._dynamo.optimize("inductor")(fn) _, code = run_and_get_cpp_code(opt_fn, *args) + num = ( + 2 + if cpu_vec_isa.valid_vec_isa_list() + and os.getenv("ATEN_CPU_CAPABILITY") != "default" + else 1 + ) FileCheck().check_count( "static_cast(256)", - 2, + num, exactly=True, ).run(code) diff --git a/torch/_inductor/cpu_vec_isa.py b/torch/_inductor/cpu_vec_isa.py index 344f2bc58f56ab..98fa37a7b68159 100644 --- a/torch/_inductor/cpu_vec_isa.py +++ b/torch/_inductor/cpu_vec_isa.py @@ -6,7 +6,8 @@ import re import subprocess import sys -from typing import Any, Callable, Dict, List +import warnings +from typing import Any, Callable, Dict, List, Union import torch from torch._inductor import config @@ -309,6 +310,35 @@ def _check_and_append_supported_isa( supported_vec_isa_list = [VecAMX(), VecAVX512(), VecAVX2(), VecNEON()] +def get_isa_from_cpu_capability( + capability: Union[str, None], + vec_isa_list: List[VecISA], + invalid_vec_isa: InvalidVecISA, +): + # AMX setting is not supported in eager + # 
VecAMX will be prioritized for selection when setting ATEN_CPU_CAPABILITY to avx512 + capability_to_isa_str = { + "default": "INVALID_VEC_ISA", + "neon": "asimd", + "zvector": "zvector", + "vsx": "vsx", + "avx2": "avx2", + "avx512": "avx512", + } + if capability in capability_to_isa_str.keys(): + isa_str = capability_to_isa_str[capability] + if isa_str == "INVALID_VEC_ISA": + return invalid_vec_isa + for vec_isa in vec_isa_list: + if isa_str in str(vec_isa): + return vec_isa + + if capability: + warnings.warn(f"ignoring invalid value for ATEN_CPU_CAPABILITY {capability}") + + return vec_isa_list[0] + + # Cache the cpuinfo to avoid I/O overhead. Meanwhile, the cpuinfo content # might have too much redundant content that is useless for ISA check. Hence, # we only cache some key isa information. @@ -359,10 +389,12 @@ def pick_vec_isa() -> VecISA: if not _valid_vec_isa_list: return invalid_vec_isa - # If the simdlen is None, it indicates determine the vectorization length automatically + # If the simdlen is None, pick the vec ISA according to the ATEN_CPU_CAPABILITY + # environment variable (first valid ISA when it is unset or unmatched) if config.cpp.simdlen is None: - assert _valid_vec_isa_list - return _valid_vec_isa_list[0] + return get_isa_from_cpu_capability( + os.getenv("ATEN_CPU_CAPABILITY"), _valid_vec_isa_list, invalid_vec_isa + ) for isa in _valid_vec_isa_list: if config.cpp.simdlen == isa.bit_width():