diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py
index 1af20b4f..9f8a0c6c 100644
--- a/lit_tests/kernel/wave/codegen.py
+++ b/lit_tests/kernel/wave/codegen.py
@@ -1740,6 +1740,58 @@ def cmp_lowerings(
         # CHECK: arith.select
 
 
+@run_test
+def test_verbose_int_comparisons():
+    """Check lowering of the named integer comparison ops on i32 operands.
+
+    Exercises tkw.gt, tkw.lt, tkw.ge and tkw.le, each feeding a tkw.select,
+    and verifies through the lit check directives below that the printed
+    module contains the matching arith.cmpi predicate and an arith.select.
+    """
+    constraints: list[tkw.Constraint] = [
+        tkw.HardwareConstraint(
+            threads_per_wave=64, waves_per_block=(1, 1, 1), vector_shapes={M: 16, N: 16}
+        )
+    ]
+    constraints += [tkw.WorkgroupConstraint(M, BLOCK_M, 0)]
+    constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)]
+    constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)]
+    constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)]
+
+    @tkw.wave(constraints)
+    def verbose_cmp_lowerings(
+        a: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
+        b: tkl.Memory[M, N, ADDRESS_SPACE, tkl.i32],
+    ):
+        a_reg = tkw.read(a, elements_per_thread=4)
+        b_reg = tkw.read(b, elements_per_thread=4)
+        sgt = tkw.gt(a_reg, b_reg)
+        s1 = tkw.select(sgt, a_reg, b_reg)
+        slt = tkw.lt(a_reg, b_reg)
+        s2 = tkw.select(slt, a_reg, b_reg)
+        sge = tkw.ge(s1, s2)
+        s3 = tkw.select(sge, s1, s2)
+        sle = tkw.le(s1, s2)
+        s4 = tkw.select(sle, s1, s2)
+        # All four selects feed the value that is written back to `a`.
+        res = s1 + s2 + s3 + s4
+        tkw.write(res, a, elements_per_thread=4)
+
+    a = torch.randint(42, (16, 16), dtype=torch.int32)
+    b = torch.randint(42, (16, 16), dtype=torch.int32)
+    with codegen_test_context():
+        print(verbose_cmp_lowerings(a, b).module_op)
+        # CHECK-LABEL: @verbose_cmp_lowerings
+        # CHECK: arith.cmpi sgt
+        # CHECK: arith.select
+        # CHECK: arith.cmpi slt
+        # CHECK: arith.select
+        # CHECK: arith.cmpi sge
+        # CHECK: arith.select
+        # CHECK: arith.cmpi sle
+        # CHECK: arith.select
+
+
 # TODO: Something is broken in codegen and we are getting int in place of fx.Node
 # @launch
 @pytest.mark.skip(reason="getitem: Currently only stub implementation")