Skip to content

Commit

Permalink
Add verbose test
Browse files Browse the repository at this point in the history
Signed-off-by: Stanley Winata <[email protected]>
  • Loading branch information
raikonenfnu committed Feb 4, 2025
1 parent 5531c94 commit 2fc4f1c
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions lit_tests/kernel/wave/codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1740,6 +1740,49 @@ def cmp_lowerings(
# CHECK: arith.select


@run_test
def test_verbose_int_comparisons():
constraints: list[tkw.Constraint] = [
tkw.HardwareConstraint(
threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
)
]
constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]

@tkw.wave(constraints)
def verbose_cmp_lowerings(
a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
):
a_reg = tkw.read(a, elements_per_thread=4)
b_reg = tkw.read(b, elements_per_thread=4)
sgt = tkw.gt(a_reg, b_reg)
s1 = tkw.select(sgt, a_reg, b_reg)
slt = tkw.lt(a_reg, b_reg)
s2 = tkw.select(slt, a_reg, b_reg)
sge = tkw.ge(s1, s2)
s3 = tkw.select(sge, s1, s2)
sle = tkw.le(s1, s2)
s4 = tkw.select(sle, s1, s2)
res = s1 + s2 + s3 + s4
tkw.write(res, a, elements_per_thread=4)

a = torch.randint(42, (16, 16), dtype=torch.int32)
b = torch.randint(42, (16, 16), dtype=torch.int32)
with codegen_test_context():
print(verbose_cmp_lowerings(a, b).module_op)
# CHECK-LABEL: @verbose_cmp_lowerings
# CHECK: arith.cmpi sgt
# CHECK: arith.select
# CHECK: arith.cmpi slt
# CHECK: arith.select
# CHECK: arith.cmpi sge
# CHECK: arith.select


# TODO: Something is broken in codegen and we are getting int in place of fx.Node
# @launch
@pytest.mark.skip(reason="getitem: Currently only stub implementation")
Expand Down

0 comments on commit 2fc4f1c

Please sign in to comment.