[halide-backend] Random number generation (pytorch#130211)
Pull Request resolved: pytorch#130211
Approved by: https://github.com/jansel
qqaatw authored and pytorchmergebot committed Jul 15, 2024
1 parent 1bc390c commit dc7725c
Showing 4 changed files with 171 additions and 16 deletions.
40 changes: 39 additions & 1 deletion test/inductor/test_halide.py
@@ -1,4 +1,6 @@
# Owner(s): ["oncall: pt2"]
import functools
import itertools
import os
import sys
import textwrap
@@ -14,6 +16,7 @@

from torch.testing._internal.common_utils import IS_CI, IS_MACOS, IS_WINDOWS
from torch.testing._internal.inductor_utils import HAS_CPU
from torch.utils._triton import has_triton


if IS_WINDOWS and IS_CI:
@@ -43,7 +46,6 @@
"halide.scan_kernels": True,
"cpu_backend": "halide",
"cuda_backend": "halide",
"fallback_random": True, # TODO(jansel): support random
}
)

@@ -195,6 +197,42 @@ def generate(g):
fn(a, b, c)
self.assertEqual(c, a + b)

@unittest.skipUnless(has_triton(), "requires triton")
def test_random_consistency(self):
seed = 1234
shape = (3, 3)
dtype = torch.float32

for (rand_fn,) in itertools.product(
(
functools.partial(torch.rand, shape, dtype=dtype, device="cuda"),
functools.partial(torch.randn, shape, dtype=dtype, device="cuda"),
functools.partial(
torch.randint,
-1000,
1000,
size=shape,
dtype=torch.int64,
device="cuda",
),
)
):

@torch.compile(backend="inductor", options={"cuda_backend": "halide"})
def get_rand_halide():
return rand_fn()

@torch.compile(backend="inductor", options={"cuda_backend": "triton"})
def get_rand_triton():
return rand_fn()

torch.manual_seed(seed)
halide_output = get_rand_halide()
torch.manual_seed(seed)
triton_output = get_rand_triton()

self.assertEqual(halide_output, triton_output)


if test_torchinductor.HAS_CPU and HAS_HALIDE:
SweepInputsCpuHalideTest = make_halide(test_torchinductor.SweepInputsCpuTest)
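A condensed sketch of the pattern the new test exercises (assumes a CUDA build with both the Halide and Triton Inductor backends available; not part of the diff): compile the same RNG call once per backend, reseed before each run, and compare the outputs.

import torch

@torch.compile(backend="inductor", options={"cuda_backend": "halide"})
def halide_rand():
    return torch.rand(3, 3, device="cuda")

@torch.compile(backend="inductor", options={"cuda_backend": "triton"})
def triton_rand():
    return torch.rand(3, 3, device="cuda")

torch.manual_seed(1234)  # both backends consume the same Philox seed/offset stream
a = halide_rand()
torch.manual_seed(1234)
b = triton_rand()
assert torch.equal(a, b)  # bitwise-identical draws across backends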
20 changes: 9 additions & 11 deletions test/inductor/test_torchinductor.py
@@ -6977,7 +6977,6 @@ def fn(a):
],
)

@skip_if_halide # rng
def test_bernoulli2(self):
def fn(a):
return aten.bernoulli(a)
@@ -8291,7 +8290,6 @@ def fn(a):
result = fn(torch.randn([1, 2, 16, 4]).requires_grad_())
result.sum().backward()

@skip_if_halide # rand
def test_dropout2(self):
n = 100000
weight = torch.ones(
@@ -8330,7 +8328,10 @@ def check(r, g):
torch.manual_seed(1234)
weight.grad.zero_()
r2, (fw_code, bw_code) = run_fw_bw_and_get_code(lambda: run(ones))
if self.device == GPU_TYPE:
if is_halide_backend(self.device):
self.assertEqual(fw_code.count("halide_helpers.rand"), 1)
self.assertEqual(bw_code.count("halide_helpers.rand"), 0)
elif self.device == GPU_TYPE:
self.assertEqual(fw_code.count("tl.rand"), 1)
self.assertEqual(bw_code.count("tl.rand"), 0)
g2 = weight.grad.clone()
@@ -8348,7 +8349,6 @@ def check(r, g):
self.assertTrue(same(g2, g3))

@config.patch(search_autotune_cache=False)
@skip_if_halide # rand
def test_dropout3(self):
m = torch.nn.Sequential(
torch.nn.Linear(32, 32, bias=False),
@@ -8367,16 +8367,14 @@ def run(x):
lambda: run(torch.randn([8, 32], device=self.device))
)

if self.device == GPU_TYPE:
if is_halide_backend(self.device):
self.assertEqual(fw_code.count("halide_helpers.rand"), 2)
self.assertEqual(bw_code.count("halide_helpers.rand"), 0)
elif self.device == GPU_TYPE:
self.assertEqual(fw_code.count("tl.rand"), 2)
self.assertEqual(bw_code.count("tl.rand"), 0)
expected_kernel = 4

self.assertEqual(
torch._inductor.metrics.generated_kernel_count, expected_kernel
)
self.assertEqual(torch._inductor.metrics.generated_kernel_count, 4)

@skip_if_halide # rand
def test_randint_kernel_count(self):
@torch._dynamo.optimize_assert("inductor")
def fn1():
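The updated dropout tests count occurrences of halide_helpers.rand in the generated forward and backward source. A similar inspection can be done outside the test harness with torch._inductor.utils.run_and_get_code (a sketch assuming a Halide-enabled build; the exact count depends on the compiled graph):

import torch
from torch._inductor.utils import run_and_get_code

@torch.compile(options={"cpu_backend": "halide"})
def f(x):
    return torch.nn.functional.dropout(x, p=0.5)

# run_and_get_code returns the eager result plus the generated source modules
_, code_list = run_and_get_code(f, torch.randn(1024))
print(sum(code.count("halide_helpers.rand") for code in code_list))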
9 changes: 5 additions & 4 deletions torch/_inductor/codegen/halide.py
@@ -404,19 +404,19 @@ def bitwise_right_shift(a, b):

@staticmethod
def rand(seed, offset):
raise Unsupported("rand")
return f"halide_helpers.rand({seed}, {offset})"

@staticmethod
def randn(seed, offset):
raise Unsupported("rand")
return f"halide_helpers.randn({seed}, {offset})"

@staticmethod
def randint64(seed, offset, low, high):
raise Unsupported("rand")
return f"halide_helpers.randint64({seed}, {offset}, {low}, {high})"

@staticmethod
def load_seed(name, offset):
raise Unsupported("rand")
return f"{ops.load(name, 0)} + {V.kernel.args.seed_offset('load_seed_offset', offset)}"

@staticmethod
def rsqrt(x):
@@ -1491,6 +1491,7 @@ def codegen_kernel(self, name=None):
code.splice(
"""
import halide as hl
from torch._inductor.runtime import halide_helpers
from math import inf, nan
@hl.generator(name="kernel")
118 changes: 118 additions & 0 deletions torch/_inductor/runtime/halide_helpers.py
@@ -0,0 +1,118 @@
# mypy: allow-untyped-defs
try:
import halide as hl # type: ignore[import-untyped, import-not-found]
except ImportError:
hl = None

PHILOX_N_ROUNDS_DEFAULT = 10 # Default number of rounds for philox

if hl is not None:
PHILOX_KEY_A_U32 = hl.u32(0x9E3779B9)
PHILOX_KEY_B_U32 = hl.u32(0xBB67AE85)
PHILOX_ROUND_A_U32 = hl.u32(0xD2511F53)
PHILOX_ROUND_B_U32 = hl.u32(0xCD9E8D57)
else:
PHILOX_KEY_A_U32 = None
PHILOX_KEY_B_U32 = None
PHILOX_ROUND_A_U32 = None
PHILOX_ROUND_B_U32 = None


def _pair_uniform_to_normal(u1, u2):
"""Box-Muller transform"""
u1 = hl.max(hl.f32(1.0e-7), u1)
th = hl.f32(6.283185307179586) * u2
r = hl.sqrt(hl.f32(-2.0) * hl.log(u1))
return r * hl.cos(th), r * hl.sin(th)


def _uint_to_uniform_float(x):
"""
Numerically stable function to convert a random uint into a random float uniformly sampled in [0, 1).
"""

# TODO:
# conditions can be simplified
# scale is ((2**23 - 1) / 2**23) * 2**(N_BITS - 1)
# https://github.com/triton-lang/triton/blob/e4a0d93ff1a367c7d4eeebbcd7079ed267e6b06f/python/triton/language/random.py#L116-L132.
assert x.type() == hl.UInt(32) or x.type() == hl.Int(32)
x = hl.cast(hl.Int(32), x)
scale = hl.f64(4.6566127342e-10)
x = hl.select(x < 0, -x - 1, x)
return x * scale


def philox_impl(c0, c1, c2, c3, k0, k1, n_rounds):
def umulhi(a, b):
a = hl.cast(hl.UInt(64), a)
b = hl.cast(hl.UInt(64), b)
return hl.cast(hl.UInt(32), ((a * b) >> 32) & hl.u64(0xFFFFFFFF))

for _ in range(n_rounds):
_c0, _c2 = c0, c2

c0 = umulhi(PHILOX_ROUND_B_U32, _c2) ^ c1 ^ k0
c2 = umulhi(PHILOX_ROUND_A_U32, _c0) ^ c3 ^ k1
c1 = PHILOX_ROUND_B_U32 * _c2
c3 = PHILOX_ROUND_A_U32 * _c0
# raise key
k0 = k0 + PHILOX_KEY_A_U32
k1 = k1 + PHILOX_KEY_B_U32

return c0, c1, c2, c3


def halide_philox(seed, c0, c1, c2, c3, n_rounds):
seed = hl.cast(hl.UInt(64), seed)

assert c0.type().bits() == 32

seed_hi = hl.cast(hl.UInt(32), (seed >> 32) & hl.u64(0xFFFFFFFF))
seed_lo = hl.cast(hl.UInt(32), seed & hl.u64(0xFFFFFFFF))

return philox_impl(c0, c1, c2, c3, seed_lo, seed_hi, n_rounds)


def randint4x(seed, offset, n_rounds):
offset = hl.cast(hl.UInt(32), offset)
_0 = hl.u32(0)
return halide_philox(seed, offset, _0, _0, _0, n_rounds)


def rand4x(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
i1, i2, i3, i4 = randint4x(seed, offset, n_rounds)
u1 = _uint_to_uniform_float(i1)
u2 = _uint_to_uniform_float(i2)
u3 = _uint_to_uniform_float(i3)
u4 = _uint_to_uniform_float(i4)
return u1, u2, u3, u4


def randint(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
ret, _, _, _ = randint4x(seed, offset, n_rounds)
return ret


def rand(seed, offset, n_rounds=PHILOX_N_ROUNDS_DEFAULT):
source = randint(seed, offset, n_rounds)
return _uint_to_uniform_float(source)


def randn(seed, offset):
i1, i2, _, _ = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT)
u1 = _uint_to_uniform_float(i1)
u2 = _uint_to_uniform_float(i2)
n1, _ = _pair_uniform_to_normal(u1, u2)
return n1


def randint64(seed, offset, low, high):
r0, r1, r2, r3 = randint4x(seed, offset, PHILOX_N_ROUNDS_DEFAULT)
r0 = hl.cast(hl.UInt(64), r0)
r1 = hl.cast(hl.UInt(64), r1)

result = r0 | (r1 << 32)
size = high - low
result = result % hl.cast(hl.UInt(64), size)
result = hl.cast(hl.Int(64), result) + low
return result
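For intuition, here is the same Philox-4x32-10 scheme and uniform conversion rendered with plain Python ints (a hypothetical illustration; the helpers above build Halide expressions rather than computing values):

MASK32 = 0xFFFFFFFF

def philox_4x32(seed, offset, n_rounds=10):
    # counter words (c1..c3 fixed at 0, as in randint4x) and key words
    c0, c1, c2, c3 = offset & MASK32, 0, 0, 0
    k0, k1 = seed & MASK32, (seed >> 32) & MASK32
    for _ in range(n_rounds):
        _c0, _c2 = c0, c2
        c0 = (((0xCD9E8D57 * _c2) >> 32) ^ c1 ^ k0) & MASK32  # umulhi(ROUND_B, c2)
        c2 = (((0xD2511F53 * _c0) >> 32) ^ c3 ^ k1) & MASK32  # umulhi(ROUND_A, c0)
        c1 = (0xCD9E8D57 * _c2) & MASK32
        c3 = (0xD2511F53 * _c0) & MASK32
        k0 = (k0 + 0x9E3779B9) & MASK32  # raise key A
        k1 = (k1 + 0xBB67AE85) & MASK32  # raise key B
    return c0, c1, c2, c3

def uniform_float(x):
    # mirror _uint_to_uniform_float: reinterpret as int32, fold negatives
    x -= (x >> 31) << 32
    if x < 0:
        x = -x - 1
    return x * 4.6566127342e-10

# Counter-based: the same (seed, offset) always reproduces the same draw,
# and bumping the offset yields the next independent draw.
assert philox_4x32(1234, 0) == philox_4x32(1234, 0)
print(uniform_float(philox_4x32(1234, 0)[0]), uniform_float(philox_4x32(1234, 1)[0]))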
