diff --git a/test/inductor/test_halide.py b/test/inductor/test_halide.py
index e470c06931b20..fdb7846561ebe 100644
--- a/test/inductor/test_halide.py
+++ b/test/inductor/test_halide.py
@@ -1,4 +1,6 @@
 # Owner(s): ["oncall: pt2"]
+import functools
+import itertools
 import os
 import sys
 import textwrap
@@ -14,6 +16,7 @@
 from torch.testing._internal.common_utils import IS_CI, IS_MACOS, IS_WINDOWS
 from torch.testing._internal.inductor_utils import HAS_CPU
+from torch.utils._triton import has_triton
 
 
 if IS_WINDOWS and IS_CI:
@@ -43,7 +46,6 @@
         "halide.scan_kernels": True,
         "cpu_backend": "halide",
         "cuda_backend": "halide",
-        "fallback_random": True,  # TODO(jansel): support random
     }
 )
 
@@ -195,6 +197,42 @@ def generate(g):
         fn(a, b, c)
         self.assertEqual(c, a + b)
 
+    @unittest.skipUnless(has_triton(), "requires triton")
+    def test_random_consistency(self):
+        seed = 1234
+        shape = (3, 3)
+        dtype = torch.float32
+
+        for (rand_fn,) in itertools.product(
+            (
+                functools.partial(torch.rand, shape, dtype=dtype, device="cuda"),
+                functools.partial(torch.randn, shape, dtype=dtype, device="cuda"),
+                functools.partial(
+                    torch.randint,
+                    -1000,
+                    1000,
+                    size=shape,
+                    dtype=torch.int64,
+                    device="cuda",
+                ),
+            )
+        ):
+
+            @torch.compile(backend="inductor", options={"cuda_backend": "halide"})
+            def get_rand_halide():
+                return rand_fn()
+
+            @torch.compile(backend="inductor", options={"cuda_backend": "triton"})
+            def get_rand_triton():
+                return rand_fn()
+
+            torch.manual_seed(seed)
+            halide_output = get_rand_halide()
+            torch.manual_seed(seed)
+            triton_output = get_rand_triton()
+
+            self.assertEqual(halide_output, triton_output)
+
 
 if test_torchinductor.HAS_CPU and HAS_HALIDE:
     SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index ed6773a0ba6e0..ec5fbace54985 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -6977,7 +6977,6 @@ def fn(a):
             ],
         )
 
-    @skip_if_halide  # rng
     def test_bernoulli2(self):
         def fn(a):
             return aten.bernoulli(a)
@@ -8291,7 +8290,6 @@ def fn(a):
         result = fn(torch.randn([1, 2, 16, 4]).requires_grad_())
         result.sum().backward()
 
-    @skip_if_halide  # rand
    def test_dropout2(self):
        n = 100000
        weight = torch.ones(
@@ -8330,7 +8328,10 @@ def check(r, g):
         torch.manual_seed(1234)
         weight.grad.zero_()
         r2, (fw_code, bw_code) = run_fw_bw_and_get_code(lambda: run(ones))
-        if self.device == GPU_TYPE:
+        if is_halide_backend(self.device):
+            self.assertEqual(fw_code.count("halide_helpers.rand"), 1)
+            self.assertEqual(bw_code.count("halide_helpers.rand"), 0)
+        elif self.device == GPU_TYPE:
             self.assertEqual(fw_code.count("tl.rand"), 1)
             self.assertEqual(bw_code.count("tl.rand"), 0)
         g2 = weight.grad.clone()
@@ -8348,7 +8349,6 @@ def check(r, g):
         self.assertTrue(same(g2, g3))
 
     @config.patch(search_autotune_cache=False)
-    @skip_if_halide  # rand
     def test_dropout3(self):
         m = torch.nn.Sequential(
             torch.nn.Linear(32, 32, bias=False),
@@ -8367,16 +8367,14 @@ def run(x):
             lambda: run(torch.randn([8, 32], device=self.device))
         )
 
-        if self.device == GPU_TYPE:
+        if is_halide_backend(self.device):
+            self.assertEqual(fw_code.count("halide_helpers.rand"), 2)
+            self.assertEqual(bw_code.count("halide_helpers.rand"), 0)
+        elif self.device == GPU_TYPE:
             self.assertEqual(fw_code.count("tl.rand"), 2)
             self.assertEqual(bw_code.count("tl.rand"), 0)
 
-        expected_kernel = 4
-
-        self.assertEqual(
-            torch._inductor.metrics.generated_kernel_count, expected_kernel
-        )
+        self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)
 
-    @skip_if_halide  # rand
     def test_randint_kernel_count(self):
         @torch._dynamo.optimize_assert("inductor")
         def fn1():
diff --git a/torch/_inductor/codegen/halide.py b/torch/_inductor/codegen/halide.py
index 44cf76cc799be..acf0ee0e957d7 100644
--- a/torch/_inductor/codegen/halide.py
+++ b/torch/_inductor/codegen/halide.py
@@ -404,19 +404,19 @@ def bitwise_right_shift(a, b):
 
     @staticmethod
     def rand(seed, offset):
-        raise Unsupported("rand")
+        return f"halide_helpers.rand({seed}, {offset})"
 
     @staticmethod
     def randn(seed, offset):
-        raise Unsupported("rand")
+        return f"halide_helpers.randn({seed}, {offset})"
 
     @staticmethod
     def randint64(seed, offset, low, high):
-        raise Unsupported("rand")
+        return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})"
 
     @staticmethod
     def load_seed(name, offset):
-        raise Unsupported("rand")
+        return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}"
 
     @staticmethod
     def rsqrt(x):
@@ -1491,6 +1491,7 @@ def codegen_kernel(self, name=None):
         code.splice(
             """
             import halide as hl
+            from torch._inductor.runtime import halide_helpers
             from math import inf, nan
 
             @hl.generator(name="kernel")
diff --git a/torch/_inductor/runtime/halide_helpers.py b/torch/_inductor/runtime/halide_helpers.py
new file mode 100644
index 0000000000000..db813b0c051a1
--- /dev/null
+++ b/torch/_inductor/runtime/halide_helpers.py
@@ -0,0 +1,118 @@
+# mypy: allow-untyped-defs
+try:
+    import halide as hl  # type: ignore[import-untyped, import-not-found]
+except ImportError:
+    hl = None
+
+PHILOX_N_ROUNDS_DEFAULT = 10  # Default number of rounds for philox
+
+if hl is not None:
+    PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9)
+    PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85)
+    PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53)
+    PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57)
+else:
+    PHILOX_KEY_A_U32 = None
+    PHILOX_KEY_B_U32 = None
+    PHILOX_ROUND_A_U32 = None
+    PHILOX_ROUND_B_U32 = None
+
+
+def _pair_uniform_to_normal(u1, u2):
+    """Box-Muller transform"""
+    u1 = hl.max(hl.f32(1.0e-7), u1)
+    th = hl.f32(6.283185307179586) * u2
+    r = hl.sqrt(hl.f32(-2.0) * hl.log(u1))
+    return r * hl.cos(th), r * hl.sin(th)
+
+
+def _uint_to_uniform_float(x):
+    """
+    Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).
+    """
+
+    # TODO:
+    # conditions can be simplified
+    # scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
+    # https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132
+    assert x.type() == hl.UInt(32) or x.type() == hl.Int(32)
+    x = hl.cast(hl.Int(32), x)
+    scale = hl.f64(4.6566127342e-10)
+    x = hl.select(x < 0, -x - 1, x)
+    return x * scale
+
+
+def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds):
+    def umulhi(a, b):
+        a = hl.cast(hl.UInt(64), a)
+        b = hl.cast(hl.UInt(64), b)
+        return hl.cast(hl.UInt(32), ((a * b) >> 32) & hl.u64(0xFFFFFFFF))
+
+    for _ in range(n_rounds):
+        _c0, _c2 = c0, c2
+
+        c0 = umulhi(PHILOX_ROUND_B_U32, _c2) ^ c1 ^ k0
+        c2 = umulhi(PHILOX_ROUND_A_U32, _c0) ^ c3 ^ k1
+        c1 = PHILOX_ROUND_B_U32 * _c2
+        c3 = PHILOX_ROUND_A_U32 * _c0
+        # raise key
+        k0 = k0 + PHILOX_KEY_A_U32
+        k1 = k1 + PHILOX_KEY_B_U32
+
+    return c0, c1, c2, c3
+
+
+def halide_philox(seed, c0, c1, c2, c3, n_rounds):
+    seed = hl.cast(hl.UInt(64), seed)
+
+    assert c0.type().bits() == 32
+
+    seed_hi = hl.cast(hl.UInt(32), (seed >> 32) & hl.u64(0xFFFFFFFF))
+    seed_lo = hl.cast(hl.UInt(32), seed & hl.u64(0xFFFFFFFF))
+
+    return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)
+
+
+def randint4x(seed, offset, n_rounds):
+    offset = hl.cast(hl.UInt(32), offset)
+    _0 = hl.u32(0)
+    return halide_philox(seed, offset, _0, _0, _0, n_rounds)
+
+
+def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
+    i1, i2, i3, i4 = randint4x(seed, offset, n_rounds)
+    u1 = _uint_to_uniform_float(i1)
+    u2 = _uint_to_uniform_float(i2)
+    u3 = _uint_to_uniform_float(i3)
+    u4 = _uint_to_uniform_float(i4)
+    return u1, u2, u3, u4
+
+
+def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
+    ret, _, _, _ = randint4x(seed, offset, n_rounds)
+    return ret
+
+
+def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
+    source = randint(seed, offset, n_rounds)
+    return _uint_to_uniform_float(source)
+
+
+def randn(seed, offset):
+    i1, i2, _, _ = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT)
+    u1 = _uint_to_uniform_float(i1)
+    u2 = _uint_to_uniform_float(i2)
+    n1, _ = _pair_uniform_to_normal(u1, u2)
+    return n1
+
+
+def randint64(seed, offset, low, high):
+    r0, r1, r2, r3 = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT)
+    r0 = hl.cast(hl.UInt(64), r0)
+    r1 = hl.cast(hl.UInt(64), r1)
+
+    result = r0 | (r1 << 32)
+    size = high - low
+    result = result % hl.cast(hl.UInt(64), size)
+    result = hl.cast(hl.Int(64), result) + low
+    return result
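
Reviewer note, not part of the patch: philox_impl mirrors the counter-based Philox4x32 rounds in triton's language/random.py, which is what lets test_random_consistency above demand matching outputs from the Halide and Triton backends under the same seed. Assuming that reading is right, the rounds can be reproduced in plain Python for spot-checking without Halide installed. The sketch below is illustrative only; the names philox4x32 and uint_to_uniform_float are made up here, and a single 32x32 -> 64-bit product stands in for the umulhi/low-half pair.

    PHILOX_KEY_A = 0x9E3779B9
    PHILOX_KEY_B = 0xBB67AE85
    PHILOX_ROUND_A = 0xD2511F53
    PHILOX_ROUND_B = 0xCD9E8D57
    MASK32 = 0xFFFFFFFF


    def philox4x32(seed, c0, c1=0, c2=0, c3=0, n_rounds=10):
        """One Philox4x32 counter block -> four uint32 outputs."""
        k0, k1 = seed & MASK32, (seed >> 32) & MASK32  # seed_lo, seed_hi
        for _ in range(n_rounds):
            a = PHILOX_ROUND_A * c0  # high half -> new c2, low half -> new c3
            b = PHILOX_ROUND_B * c2  # high half -> new c0, low half -> new c1
            c0, c1, c2, c3 = (
                ((b >> 32) ^ c1 ^ k0) & MASK32,
                b & MASK32,
                ((a >> 32) ^ c3 ^ k1) & MASK32,
                a & MASK32,
            )
            k0 = (k0 + PHILOX_KEY_A) & MASK32  # raise key, as in philox_impl
            k1 = (k1 + PHILOX_KEY_B) & MASK32
        return c0, c1, c2, c3


    def uint_to_uniform_float(x):
        """Same reinterpret-and-fold trick as _uint_to_uniform_float."""
        if x >= 1 << 31:
            x -= 1 << 32  # reinterpret the uint32 bits as a signed int32
        if x < 0:
            x = -x - 1  # fold negatives onto [0, 2**31 - 1]
        return x * 4.6566127342e-10


    # randint4x seeds the counter with the element offset in c0, zeros elsewhere:
    print([uint_to_uniform_float(v) for v in philox4x32(seed=1234, c0=0)])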
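
A second small aside: the magic constant in _uint_to_uniform_float sits deliberately a hair below 2**-31 rather than being exactly 1/2**31, so even the largest folded value maps strictly inside [0, 1). A quick check of that claim:

    scale = 4.6566127342e-10
    assert scale < 2.0**-31  # slightly under 2**-31
    assert (2**31 - 1) * scale < 1.0  # largest folded input still lands below 1.0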