diff --git a/lit_tests/kernel/wave/codegen.py b/lit_tests/kernel/wave/codegen.py index b84cc2714..8d67d42e0 100644 --- a/lit_tests/kernel/wave/codegen.py +++ b/lit_tests/kernel/wave/codegen.py @@ -606,6 +606,339 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: # CHECK: return +@run_test +def test_gemm_pipelined(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2)] + constraints += [tkw.WaveConstraint(N, BLOCK_N / 2)] + + constraints += [ + tkw.HardwareConstraint( + threads_per_wave=64, + waves_per_block=(2, 2, 1), + mma_type=tkw.MMAType.F32_16x16x16_F16, + ) + ] + + @tkw.wave(constraints) + def gemm_pipelined( + a: tkl.Memory[M, K, ADDRESS_SPACE, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE_0, tkl.f32], + ): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=LOAD_ELEMS_PER_THREAD) + b_reg = tkw.read(b, elements_per_thread=LOAD_ELEMS_PER_THREAD) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=STORE_ELEMS_PER_THREAD) + + with tk.gen.TestLaunchContext( + { + M: 128, + N: 128, + K: 128, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + LOAD_ELEMS_PER_THREAD: 4, + STORE_ELEMS_PER_THREAD: 4, + ADDRESS_SPACE: SHARED_ADDRESS_SPACE, + ADDRESS_SPACE_0: GLOBAL_ADDRESS_SPACE, + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, + }, + canonicalize=True, + schedule=True, + ): + a = torch.randn(64, 32, dtype=torch.float16) + b = torch.randn(128, 32, dtype=torch.float16) + c = torch.zeros(64, 128, dtype=torch.float32) + print(gemm_pipelined(a, b, c).module_op) + + # CHECK: func.func @gemm_pipelined(%[[ARG0:[a-zA-Z0-9_]+]]: !stream.binding, %[[ARG1:[a-zA-Z0-9_]+]]: + # CHECK-SAME: !stream.binding, %[[ARG2:[a-zA-Z0-9_]+]]: !stream.binding) attributes {translation_info = + # CHECK-SAME: #[[TRANSLATION:.+]]} { + # CHECK-DAG: %[[C19:.+]] = arith.constant 19 : index + # CHECK-DAG: %[[C18:.+]] = arith.constant 18 : index + # CHECK-DAG: %[[C17:.+]] = arith.constant 17 : index + # CHECK-DAG: %[[C3:.+]] = arith.constant 3 : index + # CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index + # CHECK-DAG: %[[C16:.+]] = arith.constant 16 : index + # CHECK-DAG: %[[C8:.+]] = arith.constant 8 : index + # CHECK-DAG: %[[C4:.+]] = arith.constant 4 : index + # CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index + # CHECK-DAG: %[[C32:.+]] = arith.constant 32 : index + # CHECK-DAG: %[[C64:.+]] = arith.constant 64 : index + # CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index + # CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<4xf32> + # CHECK: %[[WORKGROUP_ID_0:.+]] = stream.dispatch.workgroup.id[0] : index + # CHECK: %[[WORKGROUP_ID_1:.+]] = stream.dispatch.workgroup.id[1] : index + # CHECK-DAG: %[[THREAD_ID_X:.+]] = gpu.thread_id x + # CHECK-DAG: %[[THREAD_ID_Y:.+]] = gpu.thread_id y + # CHECK: %[[ALLOC:.+]] = memref.alloc() : memref<64x32xf16, #[[GPU:.+]].address_space> + # CHECK: %[[ALLOC_0:.+]] = memref.alloc() : memref<64x32xf16, #[[GPU]].address_space> + # CHECK: %[[D0:.+]] = stream.binding.subspan %[[ARG0]][%[[C0]]] : 
!stream.binding -> memref<128x128xf16, + # CHECK-SAME: strided<[128, 1], offset: ?>> + # CHECK: %[[D1:.+]] = arith.muli %[[WORKGROUP_ID_0]], %[[C64]] : index + # CHECK: %[[D2:.+]] = arith.muli %[[THREAD_ID_Y]], %[[C32]] : index + # CHECK: %[[D3:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C4]] : index + # CHECK: %[[D4:.+]] = arith.addi %[[D3]], %[[D2]] : index + # CHECK: %[[D5:.+]] = arith.remsi %[[D4]], %[[C64]] : index + # CHECK: %[[D6:.+]] = arith.addi %[[D5]], %[[D1]] : index + # CHECK: %[[D7:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C4]] : index + # CHECK: %[[D8:.+]] = arith.muli %[[D7]], %[[C8]] : index + # CHECK: %[[D9:.+]] = vector.load %[[D0]][%[[D6]], %[[D8]]] : memref<128x128xf16, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<8xf16> + # CHECK: %[[D10:.+]] = stream.binding.subspan %[[ARG1]][%[[C0]]] : !stream.binding -> memref<128x128xf16, + # CHECK-SAME: strided<[128, 1], offset: ?>> + # CHECK: %[[D11:.+]] = arith.muli %[[WORKGROUP_ID_1]], %[[C64]] : index + # CHECK: %[[D12:.+]] = arith.addi %[[D5]], %[[D11]] : index + # CHECK: %[[D13:.+]] = vector.load %[[D10]][%[[D12]], %[[D8]]] : memref<128x128xf16, strided<[128, 1], + # CHECK-SAME: offset: ?>>, vector<8xf16> + # CHECK: vector.store %[[D9]], %[[ALLOC]][%[[D5]], %[[D8]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<8xf16> + # CHECK: vector.store %[[D13]], %[[ALLOC_0]][%[[D5]], %[[D8]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<8xf16> + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D14:.+]] = arith.divsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D15:.+]] = arith.muli %[[D14]], %[[C32]] : index + # CHECK: %[[D16:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C16]] : index + # CHECK: %[[D17:.+]] = arith.addi %[[D16]], %[[D15]] : index + # CHECK: %[[D18:.+]] = arith.remsi %[[THREAD_ID_X]], %[[C64]] : index + # CHECK: %[[D19:.+]] = arith.divsi %[[D18]], %[[C16]] : index + # CHECK: %[[D20:.+]] = arith.muli %[[D19]], %[[C4]] : index + # CHECK: %[[D21:.+]] = arith.addi %[[D20]], %[[C16]] : index + # CHECK: %[[D22:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D23:.+]] = arith.addi %[[D16]], %[[D2]] : index + # CHECK: %[[D24:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D25:.+]] = arith.addi %[[D23]], %[[C16]] : index + # CHECK: %[[D26:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D27:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D28:.+]] = arith.addi %[[D8]], %[[C32]] : index + # CHECK: %[[D29:.+]] = vector.load %[[D0]][%[[D6]], %[[D28]]] : memref<128x128xf16, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<8xf16> + # CHECK: %[[D30:.+]] = vector.load %[[D10]][%[[D12]], %[[D28]]] : memref<128x128xf16, strided<[128, 1], + # CHECK-SAME: offset: ?>>, vector<8xf16> + # CHECK: %[[D31:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D32:.+]] = arith.addi %[[D17]], %[[C16]] : index + # CHECK: %[[D33:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D34:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: 
#[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D35:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D36:.+]] = amdgpu.mfma %[[D31]] * %[[D35]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D37:.+]] = amdgpu.mfma %[[D33]] * %[[D26]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D38:.+]] = amdgpu.mfma %[[D33]] * %[[D35]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D39:.+]] = amdgpu.mfma %[[D31]] * %[[D26]] + %[[CST]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: amdgpu.lds_barrier + # CHECK: vector.store %[[D29]], %[[ALLOC]][%[[D5]], %[[D8]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<8xf16> + # CHECK: vector.store %[[D30]], %[[ALLOC_0]][%[[D5]], %[[D8]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<8xf16> + # CHECK: %[[D40:.+]]:8 = scf.for %[[ARG3:[a-zA-Z0-9_]+]] = %[[C0]] to %[[C2]] step %[[C1]] + # CHECK-SAME: iter_args(%[[ARG4:[a-zA-Z0-9_]+]] = %[[D22]], %[[ARG5:[a-zA-Z0-9_]+]] = %[[D34]], + # CHECK-SAME: %[[ARG6:[a-zA-Z0-9_]+]] = %[[D24]], %[[ARG7:[a-zA-Z0-9_]+]] = %[[D27]], %[[ARG8:[a-zA-Z0-9_]+]] = + # CHECK-SAME: %[[D36]], %[[ARG9:[a-zA-Z0-9_]+]] = %[[D37]], %[[ARG10:[a-zA-Z0-9_]+]] = %[[D38]], + # CHECK-SAME: %[[ARG11:[a-zA-Z0-9_]+]] = %[[D39]]) -> (vector<4xf16>, vector<4xf16>, vector<4xf16>, + # CHECK-SAME: vector<4xf16>, vector<4xf32>, vector<4xf32>, vector<4xf32>, vector<4xf32>) { + # CHECK: %[[D90:.+]] = amdgpu.mfma %[[ARG4]] * %[[ARG6]] + %[[ARG8]] {blocks = 1 : i32, k = 16 : i32, m = + # CHECK-SAME: 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D91:.+]] = amdgpu.mfma %[[ARG5]] * %[[ARG7]] + %[[ARG9]] {blocks = 1 : i32, k = 16 : i32, m = + # CHECK-SAME: 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D92:.+]] = amdgpu.mfma %[[ARG5]] * %[[ARG6]] + %[[ARG10]] {blocks = 1 : i32, k = 16 : i32, m = + # CHECK-SAME: 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D93:.+]] = amdgpu.mfma %[[ARG4]] * %[[ARG7]] + %[[ARG11]] {blocks = 1 : i32, k = 16 : i32, m = + # CHECK-SAME: 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D94:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D95:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D96:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D97:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D98:.+]] = arith.muli %[[ARG3]], %[[C32]] : index + # CHECK: %[[D99:.+]] = arith.addi %[[D98]], %[[D8]] : index + # CHECK: %[[D100:.+]] = arith.addi %[[D99]], %[[C64]] : index + # CHECK: %[[D101:.+]] = 
vector.load %[[D0]][%[[D6]], %[[D100]]] : memref<128x128xf16, strided<[128, 1], + # CHECK-SAME: offset: ?>>, vector<8xf16> + # CHECK: %[[D102:.+]] = vector.load %[[D10]][%[[D12]], %[[D100]]] : memref<128x128xf16, strided<[128, 1], + # CHECK-SAME: offset: ?>>, vector<8xf16> + # CHECK: %[[D103:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D104:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D105:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D106:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D107:.+]] = amdgpu.mfma %[[D103]] * %[[D106]] + %[[D90]] {blocks = 1 : i32, k = 16 : i32, m = + # CHECK-SAME: 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D108:.+]] = amdgpu.mfma %[[D104]] * %[[D96]] + %[[D91]] {blocks = 1 : i32, k = 16 : i32, m = 16 + # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D109:.+]] = amdgpu.mfma %[[D104]] * %[[D106]] + %[[D92]] {blocks = 1 : i32, k = 16 : i32, m = + # CHECK-SAME: 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D110:.+]] = amdgpu.mfma %[[D103]] * %[[D96]] + %[[D93]] {blocks = 1 : i32, k = 16 : i32, m = 16 + # CHECK-SAME: : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: amdgpu.lds_barrier + # CHECK: vector.store %[[D101]], %[[ALLOC]][%[[D5]], %[[D8]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<8xf16> + # CHECK: vector.store %[[D102]], %[[ALLOC_0]][%[[D5]], %[[D8]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<8xf16> + # CHECK: scf.yield %[[D94]], %[[D105]], %[[D95]], %[[D97]], %[[D107]], %[[D108]], %[[D109]], %[[D110]] : + # CHECK-SAME: vector<4xf16>, vector<4xf16>, vector<4xf16>, vector<4xf16>, vector<4xf32>, vector<4xf32>, + # CHECK-SAME: vector<4xf32>, vector<4xf32> + # CHECK: } + # CHECK: %[[D41:.+]] = amdgpu.mfma %[[D40]]#[[D0:.+]] * %[[D40]]#[[D2:.+]] + %[[D40]]#[[D4:.+]] {blocks = 1 : + # CHECK-SAME: i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, + # CHECK-SAME: vector<4xf32> + # CHECK: %[[D42:.+]] = amdgpu.mfma %[[D40]]#[[D1:.+]] * %[[D40]]#[[D3:.+]] + %[[D40]]#[[D5:.+]] {blocks = 1 : + # CHECK-SAME: i32, k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, + # CHECK-SAME: vector<4xf32> + # CHECK: %[[D43:.+]] = amdgpu.mfma %[[D40]]#[[D1]] * %[[D40]]#[[D2]] + %[[D40]]#[[D6:.+]] {blocks = 1 : i32, + # CHECK-SAME: k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, + # CHECK-SAME: vector<4xf32> + # CHECK: %[[D44:.+]] = amdgpu.mfma %[[D40]]#[[D0]] * %[[D40]]#[[D3]] + %[[D40]]#[[D7:.+]] {blocks = 1 : i32, + # CHECK-SAME: k = 16 : i32, m = 16 : i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, + # CHECK-SAME: vector<4xf32> + # CHECK: amdgpu.lds_barrier + # CHECK: %[[D45:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D46:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, 
vector<4xf16> + # CHECK: %[[D47:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D48:.+]] = vector.load %[[ALLOC_0]][%[[D25]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D49:.+]] = vector.load %[[ALLOC]][%[[D17]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D50:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D51:.+]] = vector.load %[[ALLOC]][%[[D32]], %[[D21]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D52:.+]] = vector.load %[[ALLOC_0]][%[[D23]], %[[D20]]] : memref<64x32xf16, + # CHECK-SAME: #[[GPU]].address_space>, vector<4xf16> + # CHECK: %[[D53:.+]] = amdgpu.mfma %[[D49]] * %[[D52]] + %[[D41]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D54:.+]] = amdgpu.mfma %[[D50]] * %[[D47]] + %[[D42]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D55:.+]] = amdgpu.mfma %[[D50]] * %[[D52]] + %[[D43]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D56:.+]] = amdgpu.mfma %[[D49]] * %[[D47]] + %[[D44]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D57:.+]] = amdgpu.mfma %[[D45]] * %[[D46]] + %[[D53]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D58:.+]] = amdgpu.mfma %[[D51]] * %[[D48]] + %[[D54]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D59:.+]] = amdgpu.mfma %[[D51]] * %[[D46]] + %[[D55]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D60:.+]] = amdgpu.mfma %[[D45]] * %[[D48]] + %[[D56]] {blocks = 1 : i32, k = 16 : i32, m = 16 : + # CHECK-SAME: i32, n = 16 : i32} blgp = none : vector<4xf16>, vector<4xf16>, vector<4xf32> + # CHECK: %[[D61:.+]] = vector.extract_strided_slice %[[D57]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D62:.+]] = stream.binding.subspan %[[ARG2]][%[[C0]]] : !stream.binding -> memref<128x128xf32, + # CHECK-SAME: strided<[128, 1], offset: ?>> + # CHECK: %[[D63:.+]] = arith.addi %[[D1]], %[[D15]] : index + # CHECK: %[[D64:.+]] = arith.addi %[[D63]], %[[D20]] : index + # CHECK: %[[D65:.+]] = arith.addi %[[D16]], %[[D11]] : index + # CHECK: %[[D66:.+]] = arith.addi %[[D65]], %[[D2]] : index + # CHECK: vector.store %[[D61]], %[[D62]][%[[D64]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D67:.+]] = vector.extract_strided_slice %[[D57]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D68:.+]] = arith.addi %[[D64]], %[[C1]] : index + # CHECK: vector.store %[[D67]], %[[D62]][%[[D68]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: 
?>>, vector<1xf32> + # CHECK: %[[D69:.+]] = vector.extract_strided_slice %[[D57]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D70:.+]] = arith.addi %[[D64]], %[[C2]] : index + # CHECK: vector.store %[[D69]], %[[D62]][%[[D70]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D71:.+]] = vector.extract_strided_slice %[[D57]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D72:.+]] = arith.addi %[[D64]], %[[C3]] : index + # CHECK: vector.store %[[D71]], %[[D62]][%[[D72]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D73:.+]] = vector.extract_strided_slice %[[D58]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D74:.+]] = arith.addi %[[D64]], %[[C16]] : index + # CHECK: %[[D75:.+]] = arith.addi %[[D66]], %[[C16]] : index + # CHECK: vector.store %[[D73]], %[[D62]][%[[D74]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D76:.+]] = vector.extract_strided_slice %[[D58]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D77:.+]] = arith.addi %[[D64]], %[[C17]] : index + # CHECK: vector.store %[[D76]], %[[D62]][%[[D77]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D78:.+]] = vector.extract_strided_slice %[[D58]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D79:.+]] = arith.addi %[[D64]], %[[C18]] : index + # CHECK: vector.store %[[D78]], %[[D62]][%[[D79]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D80:.+]] = vector.extract_strided_slice %[[D58]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: %[[D81:.+]] = arith.addi %[[D64]], %[[C19]] : index + # CHECK: vector.store %[[D80]], %[[D62]][%[[D81]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D82:.+]] = vector.extract_strided_slice %[[D59]] {offsets = [0], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: vector.store %[[D82]], %[[D62]][%[[D74]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D83:.+]] = vector.extract_strided_slice %[[D59]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: vector.store %[[D83]], %[[D62]][%[[D77]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D84:.+]] = vector.extract_strided_slice %[[D59]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: vector.store %[[D84]], %[[D62]][%[[D79]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D85:.+]] = vector.extract_strided_slice %[[D59]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: vector.store %[[D85]], %[[D62]][%[[D81]], %[[D66]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D86:.+]] = vector.extract_strided_slice %[[D60]] {offsets = 
[0], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: vector.store %[[D86]], %[[D62]][%[[D64]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D87:.+]] = vector.extract_strided_slice %[[D60]] {offsets = [1], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: vector.store %[[D87]], %[[D62]][%[[D68]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D88:.+]] = vector.extract_strided_slice %[[D60]] {offsets = [2], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: vector.store %[[D88]], %[[D62]][%[[D70]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: %[[D89:.+]] = vector.extract_strided_slice %[[D60]] {offsets = [3], sizes = [1], strides = [1]} : + # CHECK-SAME: vector<4xf32> to vector<1xf32> + # CHECK: vector.store %[[D89]], %[[D62]][%[[D72]], %[[D75]]] : memref<128x128xf32, strided<[128, 1], offset: + # CHECK-SAME: ?>>, vector<1xf32> + # CHECK: return + + @run_test def test_add_float(): constraints: list[tkw.Constraint] = [ diff --git a/lit_tests/kernel/wave/scheduling.py b/lit_tests/kernel/wave/scheduling.py new file mode 100644 index 000000000..eafabb272 --- /dev/null +++ b/lit_tests/kernel/wave/scheduling.py @@ -0,0 +1,227 @@ +# RUN: python %s | FileCheck %s + +import logging +import unittest +import shark_turbine.kernel as tk +import shark_turbine.kernel.lang as tkl +import shark_turbine.kernel.wave as tkw +from shark_turbine.kernel.wave.promotion import promote_placeholders +from shark_turbine.kernel.wave.hoisting import hoist_allocs +from shark_turbine.kernel.wave.expansion import expand_graph +from shark_turbine.kernel.lang.global_symbols import * +from shark_turbine.kernel._support.tracing import CapturedTrace +from shark_turbine.kernel._support.indexing import IndexingContext +from shark_turbine.kernel.ops.wave_ops import * +from shark_turbine.kernel.wave.utils import run_test, print_subgraph +from shark_turbine.kernel.wave.minimize_global_loads import minimize_global_loads +from shark_turbine.kernel.wave.shared_memory_indexing import ( + apply_shared_memory_indexing_corrections, +) +from shark_turbine.kernel.wave.scheduling.schedule import schedule_graph + + +# Input sizes +M = tkl.sym.M +N = tkl.sym.N +K = tkl.sym.K + +# Workgroup tile sizes +BLOCK_M = tkl.sym.BLOCK_M +BLOCK_N = tkl.sym.BLOCK_N +BLOCK_K = tkl.sym.BLOCK_K + +# Address space +ADDRESS_SPACE = tkl.sym.ADDRESS_SPACE +ADDRESS_SPACE_0 = tkl.sym.ADDRESS_SPACE_0 + +# Induction variable for dimension K +ARGK = tkl.sym.ARGK + + +@tkw.wave_trace_only() +def gemm_pipelined( + a: tkl.Memory[M, K, ADDRESS_SPACE_0, tkl.f16], + b: tkl.Memory[N, K, ADDRESS_SPACE_0, tkl.f16], + c: tkl.Memory[M, N, ADDRESS_SPACE, tkl.f32], +): + c_reg = tkl.Register[M, N, tkl.f32](0.0) + + @tkw.reduction(K, init_args=[c_reg]) + def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: + a_reg = tkw.read(a, elements_per_thread=4) + b_reg = tkw.read(b, elements_per_thread=4) + acc = tkw.mma(a_reg, b_reg, acc) + return acc + + tkw.write(repeat, c, elements_per_thread=4) + + +@run_test +def test_gemm_pipelined(): + constraints: list[tkw.Constraint] = [tkw.WorkgroupConstraint(M, BLOCK_M, 0)] + constraints += [tkw.WorkgroupConstraint(N, BLOCK_N, 1)] + constraints += [tkw.WaveConstraint(M, BLOCK_M / 2, 0)] + constraints += [tkw.WaveConstraint(N, BLOCK_N 
/ 2, 1)] + constraints += [tkw.TilingConstraint(K, BLOCK_K, ARGK)] + constraints += [ + tkw.HardwareConstraint(threads_per_wave=64, waves_per_block=(2, 2, 1)) + ] + with tk.gen.TestLaunchContext( + { + M: 128, + N: 256, + K: 128, + BLOCK_M: 64, + BLOCK_N: 64, + BLOCK_K: 32, + ADDRESS_SPACE: GLOBAL_ADDRESS_SPACE, + ADDRESS_SPACE_0: SHARED_ADDRESS_SPACE, + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 2, + GLOBAL_MEMORY_UNITS: 2, + MMA_UNITS: 2, + } + ): + trace: CapturedTrace = gemm_pipelined() + IndexingContext.current().finalize() + promote_placeholders(trace, constraints) + hoist_allocs(trace) + expand_graph(trace, constraints) + minimize_global_loads(trace, constraints) + apply_shared_memory_indexing_corrections(trace, constraints) + schedule_graph(trace, constraints) + + print_subgraph(trace, "pipelined_reduction", False) + # CHECK: %acc_0_0_0 + # CHECK-NEXT: %acc_0_1_0 + # CHECK-NEXT: %acc_1_0_0 + # CHECK-NEXT: %acc_1_1_0 + # CHECK-NEXT: %rotating_reg_0 + # CHECK-NEXT: %rotating_reg_1 + # CHECK-NEXT: %rotating_reg_2 + # CHECK-NEXT: %rotating_reg_3 + # CHECK-NEXT: %rotating_reg_4 + # CHECK-NEXT: %rotating_reg_5 + # CHECK-NEXT: %rotating_reg_6 + # CHECK-NEXT: %mma_1_1_1 + # CHECK-SAME: (%rotating_reg_1, %rotating_reg_4, %rotating_reg_6) + # CHECK-NEXT: %read_shared_0_0_0 + # CHECK-NEXT: %read_shared_0_0_1 + # CHECK-NEXT: %read_4 + # CHECK-NEXT: %read_5 + # CHECK-NEXT: %read_shared_1_0_0 + # CHECK-NEXT: %read_shared_1_0_1 + # CHECK-NEXT: %mma_0_0_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_1, %acc_0_0_0) + # CHECK-NEXT: %mma_0_1_0 + # CHECK-SAME: (%read_shared_0_0_0, %rotating_reg_3, %acc_0_1_0) + # CHECK-NEXT: %mma_0_0_1 + # CHECK-SAME: (%rotating_reg_0, %rotating_reg_2, %mma_0_0_0) + # CHECK-NEXT: %mma_1_0_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_1, %acc_1_0_0) + # CHECK-NEXT: %write_2 + # CHECK-NEXT: %write_3 + # CHECK-NEXT: %mma_1_0_1 + # CHECK-SAME: (%read_shared_1_0_1, %rotating_reg_2, %mma_1_0_0) + # CHECK-NEXT: %mma_0_1_1 + # CHECK-SAME: (%rotating_reg_0, %rotating_reg_5, %mma_0_1_0) + # CHECK-NEXT: %read_shared_0_1_0 + # CHECK-NEXT: %read_shared_0_1_1 + # CHECK-NEXT: %mma_1_1_0 + # CHECK-SAME: (%read_shared_1_0_0, %rotating_reg_3, %mma_1_1_1) + # CHECK-NEXT: %read_shared_0_0_2 + # CHECK-NEXT: %read_shared_0_0_3 + # CHECK-NEXT: [mma_0_0_1, mma_0_1_1, mma_1_0_1, mma_1_1_1, read_shared_0_0_2, read_shared_1_0_1, read_shared_0_0_3, read_shared_0_1_0, rotating_reg_5, read_shared_0_1_1, mma_1_1_0] + + print_subgraph(trace, "region_1", False) + # CHECK: %a + # CHECK-NEXT: %b + # CHECK-NEXT: %c + # CHECK-NEXT: %register_0_0_0 + # CHECK-NEXT: %register_1_1_0 + # CHECK-NEXT: %register_1_0_0 + # CHECK-NEXT: %register_0_1_0 + # CHECK-NEXT: %allocate + # CHECK-NEXT: %allocate_1 + # CHECK-NEXT: %read_4 + # CHECK-NEXT: %read_5 + # CHECK-NEXT: %write_2 + # CHECK-NEXT: %write_3 + # CHECK-NEXT: %read_shared_0_1_0 + # CHECK-NEXT: %read_shared_0_1_1 + # CHECK-NEXT: %read_shared_0_0_1 + # CHECK-NEXT: %read_shared_0_0_2 + # CHECK-NEXT: %read_shared_0_0_0 + # CHECK-NEXT: %read_shared_0_0_3 + # CHECK-NEXT: %read_6 + # CHECK-NEXT: %read_7 + # CHECK-NEXT: %read_shared_1_0_0 + # CHECK-NEXT: %read_shared_1_0_1 + # CHECK-NEXT: %mma_0_0_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_0_3, %register_0_0_0) + # CHECK-NEXT: %mma_0_1_0 + # CHECK-SAME: (%read_shared_0_0_0, %read_shared_0_1_0, %register_0_1_0) + # CHECK-NEXT: %mma_0_0_1 + # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_0_2, 
%mma_0_0_0) + # CHECK-NEXT: %mma_1_0_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_0_3, %register_1_0_0) + # CHECK-NEXT: %write_4 + # CHECK-NEXT: %write_5 + # CHECK-NEXT: %mma_1_0_1 + # CHECK-SAME: (%read_shared_1_0_1, %read_shared_0_0_2, %mma_1_0_0) + # CHECK-NEXT: %mma_0_1_1 + # CHECK-SAME: (%read_shared_0_0_1, %read_shared_0_1_1, %mma_0_1_0) + # CHECK-NEXT: %read_shared_0_1_2 + # CHECK-NEXT: %read_shared_0_1_3 + # CHECK-NEXT: %mma_1_1_0 + # CHECK-SAME: (%read_shared_1_0_0, %read_shared_0_1_0, %register_1_1_0) + # CHECK-NEXT: %read_shared_0_0_4 + # CHECK-NEXT: %read_shared_0_0_5 + # CHECK-NEXT: %reduction_1 + # CHECK-NEXT: %getresult_1_1_0 + # CHECK-NEXT: %getresult_1_0_0 + # CHECK-NEXT: %getresult_0_1_0 + # CHECK-NEXT: %getresult_0_0_0 + # CHECK-NEXT: %get_result_4 + # CHECK-NEXT: %get_result_5 + # CHECK-NEXT: %get_result_6 + # CHECK-NEXT: %get_result_7 + # CHECK-NEXT: %get_result_8 + # CHECK-NEXT: %get_result_9 + # CHECK-NEXT: %get_result_10 + # CHECK-NEXT: %mma_1_1_1 + # CHECK-SAME: (%get_result_5, %get_result_9, %get_result_10) + # CHECK-NEXT: %read_shared_0_0_6 + # CHECK-NEXT: %read_shared_0_0_7 + # CHECK-NEXT: %read_shared_1_0_2 + # CHECK-NEXT: %read_shared_1_0_3 + # CHECK-NEXT: %mma_0_0_2 + # CHECK-SAME: (%read_shared_0_0_6, %read_shared_0_0_7, %getresult_0_0_0) + # CHECK-NEXT: %mma_0_1_2 + # CHECK-SAME: (%read_shared_0_0_6, %get_result_7, %getresult_0_1_0) + # CHECK-NEXT: %mma_0_0_3 + # CHECK-SAME: (%get_result_4, %get_result_6, %mma_0_0_2) + # CHECK-NEXT: %mma_1_0_2 + # CHECK-SAME: (%read_shared_1_0_2, %read_shared_0_0_7, %getresult_1_0_0) + # CHECK-NEXT: %mma_1_0_3 + # CHECK-SAME: (%read_shared_1_0_3, %get_result_6, %mma_1_0_2) + # CHECK-NEXT: %mma_0_1_3 + # CHECK-SAME: (%get_result_4, %get_result_9, %mma_0_1_2) + # CHECK-NEXT: %mma_1_1_2 + # CHECK-SAME: (%read_shared_1_0_2, %get_result_7, %mma_1_1_1) + # CHECK-NEXT: %mma_1_1_3 + # CHECK-SAME: (%read_shared_1_0_3, %get_result_9, %mma_1_1_2) + # CHECK-NEXT: %write_0_0_0 + # CHECK-NEXT: %write_1_1_0 + # CHECK-NEXT: %write_1_0_0 + # CHECK-NEXT: %write_0_1_0 + # CHECK-NEXT: return None + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/shark_turbine/kernel/_support/tracing.py b/shark_turbine/kernel/_support/tracing.py index 424242577..857cdb345 100644 --- a/shark_turbine/kernel/_support/tracing.py +++ b/shark_turbine/kernel/_support/tracing.py @@ -129,6 +129,9 @@ def __init__(self, region_graph: RegionGraph, root_graph: str): def get_subgraph(self, name: str) -> fx.Graph: return self.region_graph.subgraphs[name] + def add_subgraph(self, name: str, graph: fx.Graph): + self.region_graph.subgraphs[name] = graph + def get_root_graph(self) -> fx.Graph: return self.get_subgraph(self.root_graph) diff --git a/shark_turbine/kernel/ops/wave_ops.py b/shark_turbine/kernel/ops/wave_ops.py index ebadf0c4e..b3e6dbca4 100644 --- a/shark_turbine/kernel/ops/wave_ops.py +++ b/shark_turbine/kernel/ops/wave_ops.py @@ -27,6 +27,7 @@ if TYPE_CHECKING: from ..wave.constraints import Constraint + from ..wave.scheduling.resources import Operation T = TypeVar("T", bound=Type[Any]) AccT = TypeVar("AccT") @@ -453,6 +454,8 @@ def index(self, value: Any): self.fx_node.index = {} for dim, key in value.items(): self.fx_node.index[dim] = key + elif isinstance(value, list): + self.fx_node.index = list(value) else: raise ValueError("Index must be a dict") @@ -692,11 +695,38 @@ def is_barrier_between(self, src: fx.Node, dst: fx.Node) -> bool: prev_node, found_src = prev_node.prev, prev_node == src if not 
found_src: return False - while next_node and not found_dst: + while next_node.next.op != "root" and not found_dst: next_node, found_dst = next_node.next, next_node == dst return found_dst +@define_op("scheduling_barrier") +@dataclass +class SchedulingBarrier(CustomOp): + """ + Represents a scheduling barrier in the graph. + Takes in a list of operations that are allowed to cross + the barrier. + """ + + operations: list[Operation] + + +@define_op("scheduling_group_barrier") +@dataclass +class SchedulingGroupBarrier(CustomOp): + """ + Represents a scheduling group barrier in the graph. + The scheduling group barrier defines scheduling groups. + Each scheduling group contains differing numbers of instructions + that are captured in the instruction counts. + The sync_id identifies scheduling groups that need to be aware of each other. + """ + + instruction_counts: dict[Operation, int] + sync_id: int + + @define_op("register") @dataclass class NewRegister(CustomOp): @@ -921,6 +951,10 @@ def index(self) -> list[dict[IndexSymbol, IndexSequence]]: else None ) + @index.setter + def index(self, value: Any): + CustomOp.index.fset(self, value) + @define_op("write") @dataclass diff --git a/shark_turbine/kernel/wave/codegen.py b/shark_turbine/kernel/wave/codegen.py index e4a8cf72c..9e74366fe 100644 --- a/shark_turbine/kernel/wave/codegen.py +++ b/shark_turbine/kernel/wave/codegen.py @@ -37,6 +37,7 @@ stream_d, scf_d, vector_d, + llvm_d, ) from shark_turbine.aot.support.ir_utils import _is_float_type, _is_integer_like_type @@ -57,6 +58,8 @@ shared_memory_barrier, extract_slice, CustomOp, + scheduling_barrier, + scheduling_group_barrier, ) from ..lang.wave_types import IndexMapping, IndexSymbol from ..compiler.base import CodegenError, ValidationError, NDEBUG @@ -77,7 +80,7 @@ WorkgroupConstraint, TilingConstraint, ) -from .utils import subs_idxc +from .utils import subs_idxc, get_scheduling_mask # Indexing imports. from .._support.indexing import IndexingContext, IndexExpr, IndexSequence @@ -90,6 +93,7 @@ class WaveEmitter: root_sig: BoundKernelSignature trace: CapturedTrace constraints: list[Constraint] + scheduling_metadata: dict[fx.Node, int] ip: InsertionPoint = None OP_HANDLERS: ClassVar[dict[str, Callable[["WaveEmitter", fx.Node], None]]] = {} _node_values: ClassVar[dict[fx.Node, List[IRProxyValue]]] = {} @@ -209,13 +213,14 @@ def _get_div(mul, add, denominator): induction_var_syms = [] induction_vars = [] - for constraint in emitter.constraints: - if isinstance(constraint, TilingConstraint): - assert ( - constraint.dim in emitter.induction_vars - ), f"Could not find induction var for {constraint.dim} dimension" - induction_var_syms.append(constraint.induction_var) - induction_vars.append(emitter.induction_vars[constraint.dim]) + if emitter.induction_vars: + for constraint in emitter.constraints: + if isinstance(constraint, TilingConstraint): + assert ( + constraint.dim in emitter.induction_vars + ), f"Could not find induction var for {constraint.dim} dimension" + induction_var_syms.append(constraint.induction_var) + induction_vars.append(emitter.induction_vars[constraint.dim]) # TODO: factor this out all_symbols = emitter.thread_ids + emitter.workgroup_ids + induction_vars @@ -910,7 +915,6 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node): flat_init_args, _ = pytree.tree_flatten((init_args)) flat_init_args = [cast_py_value(emitter, arg) for arg in flat_init_args] - # Without scheduling, we assume that we always start at 0. 
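+ # With scheduling enabled, the loop bound is overridden below from emitter.scheduling_metadata, since the pipelined prologue and epilogue absorb some iterations of the original loop.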
start = arith_d.constant(IndexType.get(), int(0)) count = None @@ -921,7 +925,10 @@ def handle_reduction(emitter: WaveEmitter, node: fx.Node): # For now, we assume that dimensions that have tiling constraints on them, # do not have any other constraints. - end = arith_d.constant(IndexType.get(), int(count)) + end_value = int(count) + if node in emitter.scheduling_metadata: + end_value = emitter.scheduling_metadata[node] + end = arith_d.constant(IndexType.get(), end_value) # Since we divide the end by the tile size, we need to make sure that the # step is 1. @@ -970,6 +977,38 @@ def handle_shared_memory_barrier(emitter: WaveEmitter, node: fx.Node): amdgpu_d.lds_barrier() +@handle_op(scheduling_barrier) +def handle_scheduling_barrier(emitter: WaveEmitter, node: fx.Node): + try: + operations = node.args[0] + except ValueError as e: + raise ValidationError("Malformed arguments") from e + mask = 1 + for operation in operations: + mask |= get_scheduling_mask(operation) + + mask = arith_d.constant(IntegerType.get_signless(32), mask) + llvm_d.call_intrinsic(None, "llvm.amdgcn.sched.barrier", [mask]) + + +@handle_op(scheduling_group_barrier) +def handle_scheduling_group_barrier(emitter: WaveEmitter, node: fx.Node): + try: + instruction_counts, sync_id = node.args + except ValueError as e: + raise ValidationError("Malformed arguments") from e + for operation, counts in instruction_counts.items(): + if get_scheduling_mask(operation) is None: + continue + mask = arith_d.constant( + IntegerType.get_signless(32), get_scheduling_mask(operation) + ) + counts = arith_d.constant(IntegerType.get_signless(32), counts) + llvm_d.call_intrinsic( + None, "llvm.amdgcn.sched.group.barrier", [mask, counts, sync_id] + ) + + ############################################################################### # Slicing ops ############################################################################### diff --git a/shark_turbine/kernel/wave/scheduling/graph_utils.py b/shark_turbine/kernel/wave/scheduling/graph_utils.py index e625b6663..af398af34 100644 --- a/shark_turbine/kernel/wave/scheduling/graph_utils.py +++ b/shark_turbine/kernel/wave/scheduling/graph_utils.py @@ -213,12 +213,13 @@ def topological_sort_nodes( Perform a topological sort on the nodes in the strongly connected component that have an edge in edges, excluding certain nodes. 
""" - scc_nodes = set(scc) - set(exclude) + scc_nodes = set(scc) filtered_nodes = set() for edge in edges: if edge._from in scc_nodes and edge._to in scc_nodes: filtered_nodes.add(edge._to) filtered_nodes.add(edge._from) + filtered_nodes -= set(exclude) if exclude is not None else set() sorted_nodes = sorted(filtered_nodes, key=lambda x: x.f) return sorted_nodes diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py b/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py new file mode 100644 index 000000000..52f205b12 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/loop_reconstruction.py @@ -0,0 +1,556 @@ +from ..constraints import Constraint, TilingConstraint +from ..._support.indexing import IndexSymbol +from ..._support.tracing import CapturedTrace +from ...ops.wave_ops import ( + Reduction, + IterArg, + Placeholder, + Allocate, + Output, + Write, + GetResult, + get_custom, +) +from .modulo_scheduling import ModuloScheduler +from ..utils import ( + graph_copy, + erase_graph, + get_induction_variable, + replace_uses_in, +) +from ..utils import subs_idxc +import torch.fx as fx +import math +from collections import deque +from ..visualization import visualize_mapped_graphs, visualize_graph +from ....support.logging import get_logger +from ...lang.global_symbols import SHARED_ADDRESS_SPACE +import random +from typing import Optional +from .loop_reconstruction_utils import ( + ArgumentContext, + create_fill_stage_schedule, + create_drain_stage_schedule, + liveness_analysis, + partition_graph_by_stage, + interleave_instructions, +) +from enum import Enum + +logger = get_logger("turbine.wave.scheduling.loop_reconstruction") + + +class PipelineStage(Enum): + PROLOGUE = 0 + KERNEL = 1 + EPILOGUE = 2 + + +def add_nodes_by_schedule( + reduction_graph: fx.Graph, + partitioned_graph: list[dict[int, fx.Node]], + arg_context: ArgumentContext, + stages: list[int], + initiation_interval: int, + induction_variable: IndexSymbol, + current_induction_variables: list[int], + rotating_registers: dict[fx.Node, list[fx.Node]], + pipelining_stage: PipelineStage = PipelineStage.KERNEL, +): + """ + Interleave the instructions in the partitioned graph by stage + for a single initiation interval, updating the argument maps + per stage starting at the provided start times and indices. + """ + fill_or_drain = pipelining_stage in [PipelineStage.PROLOGUE, PipelineStage.EPILOGUE] + fill = pipelining_stage == PipelineStage.PROLOGUE + drain = pipelining_stage == PipelineStage.EPILOGUE + + for cycle in range(initiation_interval): + logger.debug(f"Cycle: {cycle}") + # Interleave the instructions that are scheduled at the same cycle. + interleaved_instructions = [] + for iteration, stage in enumerate(stages): + if stage is None: + continue + if cycle not in partitioned_graph[stage]: + continue + for node in partitioned_graph[stage][cycle]: + interleaved_instructions.append((iteration, stage, node)) + interleave_instructions(interleaved_instructions) + + for iteration, stage, node in interleaved_instructions: + logger.debug(f"Node: {node}, Stage: {stage}, Iteration: {iteration}") + custom_node = get_custom(node) + logger.debug(f"Node args: {node.args}") + for arg in node.args: + if arg_context.contains_in_iteration(iteration, arg): + logger.debug( + f"Found arg: {arg} in partitioned argument map. Using {arg_context.get_from_iteration(iteration, arg)}." 
+ ) + continue + new_node = custom_node.copy( + new_graph=reduction_graph, + arg_transform=lambda x: ( + arg_context.get_from_iteration(iteration, x) + if arg_context.contains_in_iteration(iteration, x) + else x + ), + ) + # Update the argument context. + arg_context[(iteration, stage, node)] = new_node.fx_node + logger.debug( + f"Copying Node: {node}, Stage: {stage}, Iteration: {iteration} -> {new_node.fx_node}" + ) + # Set the index for the new node by substituting the induction variable + # for the current iteration. + new_node.index = node.index + for dim in new_node.index: + new_node.index[dim] = new_node.index[dim].subs( + {induction_variable: current_induction_variables[iteration]} + ) + # Add scheduling parameters for debugging. + new_node.scheduling_parameters = node.scheduling_parameters + # Update the rotating registers and argument context for the current node (if applicable). + if node in rotating_registers: + rotating_registers[node].append(new_node.fx_node) + rotating_registers[node].popleft() + # If draining, then override the rotating registers and update the argument context. + if fill_or_drain: + for next_stage in range(stage + 1, len(stages)): + arg_context[(iteration, next_stage, node)] = new_node.fx_node + + # Update the init args in the argument context whenever a result is computed. + if node in arg_context.results: + if ( + pipelining_stage == PipelineStage.KERNEL + or pipelining_stage == PipelineStage.EPILOGUE + ): + logger.debug( + f"Updating result: {node} -> {arg_context.result_to_iter_arg[node]} to {new_node.fx_node}." + ) + arg_context.map_arg_all( + arg_context.result_to_iter_arg[node], new_node.fx_node + ) + if pipelining_stage == PipelineStage.PROLOGUE: + logger.debug( + f"Updating result: {node} -> {arg_context.result_to_init_arg[node]} to {new_node.fx_node}." + ) + arg_context.map_arg_all( + arg_context.result_to_init_arg[node], new_node.fx_node + ) + + +def push_placeholders( + implicit_captures: list[fx.Node], + reduction_subgraph: fx.Node, + arg_context: ArgumentContext, +): + """ + Push placeholders into the argument context for the reduction graph. + """ + for node in reduction_subgraph.nodes: + custom = get_custom(node) + if isinstance(custom, Placeholder) and not isinstance(custom, IterArg): + root_node = [x for x in implicit_captures if x.name == node.name][0] + assert root_node is not None + arg_context.map_arg_all(node, root_node) + + +def construct_prologue( + reduction_subgraph: fx.Graph, + reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + stages: list[int], +): + """ + Construct the prologue of the pipelined loop. + For this, we need to copy nodes from the reduction_graph and insert them + before the reduction operator in the root graph in the appropriate order. + We also need to initialize the rotating registers and update the indices + of the nodes to use the appropriate values of the induction variable. + """ + logger.debug("=====================================") + logger.debug("Constructing prologue.") + logger.debug("=====================================") + + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + + # Map iter args to init args in the prologue. 
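+ # At this point no iteration has executed, so any use of an iter arg in the copied prologue nodes must resolve to the corresponding init arg; map_arg_all registers this for every stage and iteration.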
+ for iter_arg, init_arg in zip( + reduction.iter_args(reduction_subgraph), reduction.init_args + ): + arg_context.map_arg_all(iter_arg, init_arg) + + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + with reduction.graph.inserting_before(reduction.fx_node): + for i in range(scheduler.num_stages - 1): + add_nodes_by_schedule( + reduction.graph, + partitioned_graph, + arg_context, + stages[i], + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + rotating_registers, + PipelineStage.PROLOGUE, + ) + + # During the prologue, we may have computed results that need to be passed as init args + # to the kernel. + new_init_args: list[fx.Node] = [] + for init_arg in reduction.init_args: + mapped_init_arg = arg_context.lookup(init_arg) + if mapped_init_arg is None: + mapped_init_arg = init_arg + new_init_args.append(mapped_init_arg) + reduction.init_args = new_init_args + + +def flatten_dict_values( + rotating_registers: dict[fx.Node, list[fx.Node]] +) -> list[fx.Node]: + """ + Flatten the values of the rotating registers into a list. + """ + return [ + register for registers in rotating_registers.values() for register in registers + ] + + +def unflatten_dict_values( + rotating_registers_shapes: dict[fx.Node, int], values: list[fx.Node] +) -> dict[fx.Node, list[fx.Node]]: + """ + Unflatten the values of the rotating registers into a dictionary + using the provided shapes. + """ + rotating_registers = {} + count = 0 + for node, shape in rotating_registers_shapes.items(): + rotating_registers[node] = deque(values[count : count + shape]) + count += shape + assert count == sum(rotating_registers_shapes.values()) + return rotating_registers + + +def push_rotating_registers( + arg_context: ArgumentContext, + rotating_registers: dict[fx.Node, list[fx.Node]], + graph: fx.Graph, + node_map: dict[fx.Node, fx.Node], + create_new_nodes: bool = False, +) -> dict[fx.Node, deque[fx.Node]]: + """ + Pushes the rotating registers into the argument map + at the appropriate stages. Create new nodes in the + specified graph if requested. + + For each rotating register, + we evaluate which stage it belongs to and update the argument + context for the next stage and n - 1 stages after it, where + n is the total number of rotating registers. + If var a has [a, b, c] as rotating registers, then in a 3-stage schedule + a is used in stage 2, (iteration 0) + b in stage 1, (iteration 1) + c in stage 0. (iteration 2) + """ + new_rotating_registers: dict[fx.Node, deque[fx.Node]] = {} + count = 0 + for node, registers in rotating_registers.items(): + new_registers: deque[fx.Node] = deque() + custom = get_custom(node) + stage = custom.scheduling_parameters["stage"] + iteration = arg_context.get_kernel_iteration(stage) + arg_context[(iteration, stage, node)] = registers[-1] + for i, register in enumerate(registers): + mapped_stage = stage + len(registers) - i + mapped_iteration = arg_context.get_kernel_iteration(mapped_stage) + if create_new_nodes: + iter_arg = IterArg(f"rotating_reg_{count}").add_to_graph(graph) + iter_arg.type = get_custom(node).type + iter_arg.index = get_custom(node).index + arg_context[(mapped_iteration, mapped_stage, node)] = iter_arg + new_registers.append(iter_arg) + logger.debug( + f"Mapped orig: {node_map[node]} / mapped: {iter_arg} to stage {mapped_stage}." + ) + else: + arg_context[(mapped_iteration, mapped_stage, node)] = register + logger.debug( + f"Mapped orig: {node_map[node]} / mapped: {register} to stage {mapped_stage}." 
+ ) + count += 1 + if new_registers: + new_rotating_registers[node] = new_registers + return new_rotating_registers + + +def construct_kernel( + reduction_subgraph: fx.Graph, + reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + node_map: dict[fx.Node, fx.Node], + visualize: bool = False, +) -> tuple[Reduction, fx.Graph]: + """ + Construct the kernel of the pipelined loop. + First, we construct a new reduction op with an empty graph. + Then, we set the init args, construct the iter args and add the ops. + Finally, we create the output node with the return values. + The iter args/results of the pipelined reduction are always: + [results0, result1, ..., resultN, rotating_reg0, rotating_reg1, ..., rotating_regN] + """ + logger.debug("=====================================") + logger.debug("Constructing kernel.") + logger.debug("=====================================") + + with reduction.graph.inserting_before(reduction.fx_node): + pipelined_reduction = Reduction( + reduction.axis, + init_args=reduction.init_args + flatten_dict_values(rotating_registers), + subgraph_name="pipelined_reduction", + implicit_captures=reduction.implicit_captures, + ).add_to_graph(reduction.graph) + pipelined_reduction.index = reduction.index + pipelined_reduction_graph = fx.Graph() + reduction.graph.subgraphs["pipelined_reduction"] = pipelined_reduction_graph + + # Update the argument map for the new reduction. + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + + # For the original iter args, we just map the old ones to the new ones. + # Do this for all stages, since the original iter args are "dummy" nodes + # during scheduling. + for node in arg_context.iter_args: + iter_arg = IterArg(node.name).add_to_graph(pipelined_reduction_graph) + iter_arg.type = get_custom(node).type + iter_arg.index = get_custom(node).index + arg_context.map_arg_all(node, iter_arg) + + # Push the rotating registers into the argument context. + new_rotating_registers: dict[fx.Node, deque[fx.Node]] = push_rotating_registers( + arg_context, + rotating_registers, + pipelined_reduction_graph, + node_map, + create_new_nodes=True, + ) + + add_nodes_by_schedule( + pipelined_reduction_graph, + partitioned_graph, + arg_context, + list(reversed(range(scheduler.num_stages))), + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + new_rotating_registers, + PipelineStage.KERNEL, + ) + + # Create output node (last node in the graph). 
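+ # The return values follow the iter-arg layout described in the docstring: original results first, then the rotating registers, so that the GetResult indices computed in the epilogue line up.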
+ return_vals: list[fx.Node] = arg_context.get_kernel_results() + for registers in new_rotating_registers.values(): + return_vals.extend(registers) + + Output(return_vals).add_to_graph(pipelined_reduction_graph) + reduction.replace_all_uses_with(pipelined_reduction) + + if visualize: + visualize_mapped_graphs( + pipelined_reduction_graph, + new_rotating_registers, + arg_context.argument_map, + "kernel.png", + ) + + return pipelined_reduction, pipelined_reduction_graph + + +def construct_epilogue( + reduction_subgraph: fx.Graph, + reduction: Reduction, + pipelined_reduction: Reduction, + partitioned_graph: list[dict[int, fx.Node]], + scheduler: ModuloScheduler, + rotating_registers: dict[fx.Node, list[fx.Node]], + induction_variable: IndexSymbol, + new_induction_variables: list[int], + stages: list[int], + num_rotating_registers: dict[fx.Node, int], + node_map: dict[fx.Node, fx.Node], + visualize: bool = False, +): + """ + Construct the epilogue of the pipelined loop. + The difference from the prologue is that we need to map the results + of the pipelined reduction to the remaining stages. (In the prologue, + no iteration is every completed and so we don't compute the final results) + We emit GetResult nodes for the rotating registers and map them to + the different epilogue stages. + """ + logger.debug("=====================================") + logger.debug("Constructing epilogue.") + logger.debug("=====================================") + + arg_context = ArgumentContext( + reduction.outputs(reduction_subgraph), + reduction.iter_args(reduction_subgraph), + reduction.init_args, + scheduler.num_stages, + ) + + existing_get_results: list[GetResult] = sorted( + [x for x in pipelined_reduction.users if isinstance(x, GetResult)], + key=lambda x: x.res_idx, + ) + existing_users = {x: x.users for x in existing_get_results} + + # Map the results from the kernel to the init args (for stages). + for iter_arg, get_result in zip( + reduction.iter_args(reduction_subgraph), existing_get_results + ): + arg_context.map_arg_all(iter_arg, get_result.fx_node) + + with pipelined_reduction.graph.inserting_before( + existing_get_results[0].fx_node.next + ): + # Add get result nodes for the rotating registers and update the + # argument map with them. + rotating_registers_get_results = [] + offset = len(existing_get_results) + for i in range(len(flatten_dict_values(rotating_registers))): + rotating_registers_get_results.append( + GetResult(pipelined_reduction.fx_node, i + offset).add_to_graph( + pipelined_reduction.graph + ) + ) + rotating_registers = unflatten_dict_values( + num_rotating_registers, rotating_registers_get_results + ) + + # Push the rotating registers onto the argument map. + push_rotating_registers(arg_context, rotating_registers, None, node_map, False) + push_placeholders(reduction.implicit_captures, reduction_subgraph, arg_context) + + for i in range(scheduler.num_stages - 1): + add_nodes_by_schedule( + pipelined_reduction.graph, + partitioned_graph, + arg_context, + stages[i], + scheduler.initiation_interval, + induction_variable, + new_induction_variables, + rotating_registers, + PipelineStage.EPILOGUE, + ) + + # Replace the existing uses with the new results. 
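+ # For each original result, get_mapped_results yields the value recomputed in the final epilogue iteration, or falls back to the corresponding GetResult of the pipelined reduction if it was not recomputed there.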
+ new_results = arg_context.get_mapped_results(existing_get_results) + assert len(new_results) == len(existing_get_results) + for i, get_result in enumerate(existing_get_results): + replace_uses_in(existing_users, get_result, new_results[i]) + + if visualize: + visualize_mapped_graphs( + pipelined_reduction.graph, + rotating_registers, + arg_context.argument_map, + "epilogue.png", + ) + + +def construct_pipelined_loop( + trace: CapturedTrace, + reduction: Reduction, + graph: fx.Graph, + constraints: list[Constraint], + scheduler: ModuloScheduler, + node_map: dict[fx.Node, fx.Node], + max_induction_variable: int, + visualize: bool = False, +) -> fx.Node: + """ + Given a graph annotated with scheduling parameters, construct a pipelined loop + with a prologue, kernel and epilogue. + """ + induction_variable = get_induction_variable(reduction, constraints) + num_rotating_registers = liveness_analysis(graph, constraints, scheduler) + rotating_registers: dict[fx.Node, deque[fx.Node]] = { + k: deque([None for _ in range(v)]) for k, v in num_rotating_registers.items() + } + partitioned_graph = partition_graph_by_stage(graph, scheduler) + # Construct prologue. + construct_prologue( + graph, + reduction, + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + list(range(scheduler.num_stages)), + create_fill_stage_schedule(scheduler.num_stages), + ) + # Construct kernel. + pipelined_reduction, pipelined_reduction_graph = construct_kernel( + graph, + reduction, + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + [induction_variable + i for i in range(scheduler.num_stages)], + node_map, + visualize, + ) + trace.add_subgraph( + get_custom(pipelined_reduction).subgraph_name, pipelined_reduction_graph + ) + # Construct epilogue. + construct_epilogue( + graph, + reduction, + get_custom(pipelined_reduction), + partitioned_graph, + scheduler, + rotating_registers, + induction_variable, + [ + max_induction_variable - scheduler.num_stages + i + for i in range(scheduler.num_stages) + ], + create_drain_stage_schedule(scheduler.num_stages), + num_rotating_registers, + node_map, + visualize, + ) + + # Remove the unpipelined reduction. + reduction.graph.erase_node(reduction.fx_node) + + if visualize: + visualize_graph(pipelined_reduction.graph, "pipelined.png") + + return pipelined_reduction diff --git a/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py b/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py new file mode 100644 index 000000000..b6993a216 --- /dev/null +++ b/shark_turbine/kernel/wave/scheduling/loop_reconstruction_utils.py @@ -0,0 +1,285 @@ +from ..constraints import Constraint, TilingConstraint +from ..._support.indexing import IndexSymbol +from ..._support.tracing import CapturedTrace +from ...ops.wave_ops import Reduction, IterArg, Output, Write, GetResult, get_custom +from .modulo_scheduling import ModuloScheduler +from ..utils import graph_copy, erase_graph +from ..utils import subs_idxc +import torch.fx as fx +import math +from collections import defaultdict, deque, ChainMap +from ..visualization import visualize_mapped_graphs +from ....support.logging import get_logger +from ...lang.global_symbols import SHARED_ADDRESS_SPACE +import random +from typing import Optional + +logger = get_logger("turbine.wave.scheduling.loop_reconstruction_utils") + + +class ArgumentContext: + """ + The argument context is used to store the mapping of arguments + for each modulo pipelining stage. 
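+
+ Mappings are keyed by (iteration, stage, node). A minimal illustrative sketch (the node names here are placeholders, not part of this change):
+
+ arg_context = ArgumentContext(results, iter_args, init_args, num_stages=3)
+ arg_context[(0, 2, node)] = new_node  # value to use for `node` in stage 2 of iteration 0
+ arg_context.map_arg_all(old_node, new_node)  # register a mapping for every (iteration, stage)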
+ """ + + def __init__( + self, + results: list[fx.Node], + iter_args: list[fx.Node], + init_args: list[fx.Node], + num_stages: int, + ) -> None: + self.argument_map: list[list[dict[fx.Node, fx.Node]]] = [ + [{} for _ in range(num_stages)] for _ in range(num_stages) + ] + self.results = results + self.iter_args = iter_args + self.init_args = init_args + self.num_stages = num_stages + self.num_iterations = num_stages + self.result_to_iter_arg: dict[fx.Node, fx.Node] = {} + self.result_to_init_arg: dict[fx.Node, fx.Node] = {} + + for result, iter_arg in zip(results, iter_args): + self.result_to_iter_arg[result] = iter_arg + for result, init_arg in zip(results, init_args): + self.result_to_init_arg[result] = init_arg + + def map_arg_all(self, from_: fx.Node, to_: fx.Node) -> None: + """ + Maps the given argument from one to another into the argument context for all stages + and for all iterations. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + self.argument_map[iteration][stage][from_] = to_ + + def map_arg_all_iterations(self, stage: int, from_: fx.Node, to_: fx.Node) -> None: + """ + Maps the given argument from one to another into the argument context for all stages + and for all iterations. + """ + for iteration in range(self.num_iterations): + self.argument_map[iteration][stage][from_] = to_ + + def get_mapped_results(self, get_results: list[GetResult]) -> list[fx.Node]: + """ + Gets the mapped results from the last iteration. If the result is not + in the last iteration, then get it from the get result nodes. + """ + mapped_results = [] + for result, get_result in zip(self.results, get_results): + stage = result.scheduling_parameters["stage"] + if result not in self.argument_map[self.num_iterations - 1][stage]: + mapped_results.append(get_result.fx_node) + else: + mapped_results.append( + self.argument_map[self.num_iterations - 1][stage][result] + ) + return mapped_results + + def get_kernel_iteration(self, stage: int) -> int: + """ + Get the iteration from the stage for the kernel. + """ + return self.num_stages - 1 - stage + + def get_kernel_results(self) -> list[fx.Node]: + """ + Gets the mapped results for the kernel. Here there + exists a fixed relationship between the iteration and stage. + """ + mapped_results = [] + for result in self.results: + stage = result.scheduling_parameters["stage"] + iteration = self.get_kernel_iteration(stage) + mapped_results.append(self.argument_map[iteration][stage][result]) + return mapped_results + + def __setitem__(self, key: tuple[int, fx.Node], value: fx.Node) -> None: + """ + Sets the argument mapping for the given stage. + """ + assert isinstance(key, tuple), "Argument context key must be a tuple" + iteration, stage, from_ = key + assert iteration < len( + self.argument_map + ), f"Iteration {iteration} not yet initialized" + assert stage < len(self.argument_map), f"Stage {stage} not yet initialized" + self.argument_map[iteration][stage][from_] = value + + def __getitem__(self, value: tuple[int, fx.Node]) -> fx.Node: + """ + Gets the argument mapping for the given stage. 
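+        The key is an (iteration, stage, node) tuple; returns None if the node
+        has no mapping in that slot.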
+ """ + assert isinstance(value, tuple), "Argument context key must be a tuple" + iteration, stage, key = value + assert iteration < len( + self.argument_map + ), f"Iteration {iteration} not yet initialized" + assert stage < len(self.argument_map), f"Stage {stage} not yet initialized" + return self.argument_map[iteration][stage].get(key, None) + + def __contains__(self, key: fx.Node | tuple[int, fx.Node]) -> bool: + """ + Checks if the argument context contains the given node at a specified + iteration and stage or at all iterations and stages. + """ + if isinstance(key, tuple): + iteration, stage, key = key + return key in self.argument_map[iteration][stage] + return any( + key in self.argument_map[iteration][stage] + for iteration in range(self.num_iterations) + for stage in range(self.num_stages) + ) + + def lookup(self, key: fx.Node) -> Optional[fx.Node]: + """ + Looks up the argument mapping for the given node. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + if key in self.argument_map[iteration][stage]: + return self.argument_map[iteration][stage][key] + return None + + def contains_in_iteration(self, iteration: int, key: fx.Node) -> bool: + """ + Checks if the argument context contains the given node at a specified + iteration. + """ + return any( + key in self.argument_map[iteration][stage] + for stage in range(self.num_stages) + ) + + def get_from_iteration(self, iteration: int, key: fx.Node) -> fx.Node: + """ + Gets the argument mapping for the given iteration. + """ + for stage in range(self.num_stages): + if key in self.argument_map[iteration][stage]: + return self.argument_map[iteration][stage][key] + return None + + def dump(self): + """ + Dump the argument context to the logger. + """ + for iteration in range(self.num_iterations): + for stage in range(self.num_stages): + logger.debug(f"Iteration: {iteration}, Stage: {stage}") + for key, value in self.argument_map[iteration][stage].items(): + logger.debug(f" {key} -> {value}") + + +def create_fill_stage_schedule(n: int) -> list[list[int]]: + """ + Create the schedule of which stages need to be interleaved for the prologue (fill). + This looks like: + [0 None None None] + [1 0 None None] + [2 1 0 None] + """ + schedule = [] + for i in range(n - 1): + row = list(range(i, -1, -1)) + row.extend([None] * (n - i - 1)) + schedule.append(row) + return schedule + + +def create_drain_stage_schedule(n: int) -> list[list[int]]: + """ + Create the schedule of which stages need to be interleaved for the epilogue (drain). + This looks like: + [None 3 2 1] + [None None 3 2] + [None None None 3] + """ + schedule = [] + for i in range(n - 1): + row = [None] * (i + 1) + row.extend(range(n - 1, i, -1)) + schedule.append(row) + return schedule + + +def liveness_analysis( + graph: fx.Graph, constraints: list[Constraint], scheduler: ModuloScheduler +) -> dict[fx.Node, int]: + """ + Perform liveness analysis on the graph to determine the live ranges of + variables and use that to deduce how many rotating registers we need. 
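+    A value defined in stage s whose furthest user is in stage s + l is given
+    l rotating registers, so copies produced by successive iterations do not
+    clobber each other while still live.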
+ """ + lifetime: dict[fx.Node, int] = {} + for node in graph.nodes: + custom = get_custom(node) + if custom.scheduling_parameters is None: + continue + if node not in lifetime: + lifetime[node] = 0 + for user in custom.users: + if user.scheduling_parameters is None: + continue + logger.debug( + f"Node: {node}, User: {user.fx_node}, lifetime: {user.scheduling_parameters['stage'] - custom.scheduling_parameters['stage']}" + ) + lifetime[node] = max( + user.scheduling_parameters["stage"] + - custom.scheduling_parameters["stage"], + lifetime[node], + ) + + # Determine how many copies we need for each node. If the lifetime of a node + # is l clocks and the initiation interval is T, then only ceil(l/T) values + # of the node can be live at the same time. We need to create copies of only + # those nodes that are live at more than one stage. + num_rotating_registers: dict[fx.Node, int] = {} + for node, l in lifetime.items(): + if node in num_rotating_registers: + continue + custom = get_custom(node) + if ( + isinstance(custom, Write) + and custom.memory_type.address_space == SHARED_ADDRESS_SPACE + ): + continue + if l > 0: + num_rotating_registers[node] = l + + return num_rotating_registers + + +def partition_graph_by_stage( + graph: fx.Graph, scheduler: ModuloScheduler +) -> list[dict[int, list[fx.Node]]]: + """ + Partition the graph into stages based on the scheduling parameters. + """ + partitioned_graph: list[dict[int, list[fx.Node]]] = [ + defaultdict(list) for _ in range(scheduler.num_stages) + ] + for stage in range(scheduler.num_stages): + for node in graph.nodes: + custom = get_custom(node) + if custom.scheduling_parameters is None: + continue + if isinstance(custom, IterArg): + continue + if custom.scheduling_parameters["stage"] == stage: + cycle = custom.scheduling_parameters["cycle"] + partitioned_graph[stage][cycle].append(node) + return partitioned_graph + + +def interleave_instructions(instructions: list[tuple[int, int, fx.Node]]): + """ + Interleave the instructions that are scheduled in the same cycle. + Currently, we just randomly shuffle them, but we could also sort + them based on some criteria. + """ + rng = random.Random(0) + # rng.shuffle(instructions) diff --git a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py index f2abbd132..82940113e 100644 --- a/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py +++ b/shark_turbine/kernel/wave/scheduling/modulo_scheduling.py @@ -18,6 +18,7 @@ ) from typing import Callable import numpy as np +import math logger = get_logger("turbine.wave.modulo_scheduling") @@ -263,3 +264,11 @@ def resource_reservations(self) -> np.array: Returns the resource reservations of the schedule. """ return self.RT + + @property + def num_stages(self) -> int: + """ + Returns the number of stages in the kernel of the pipelined loop. 
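+        Computed as the length of the schedule in cycles divided by the
+        initiation interval, rounded up.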
+ """ + max_cycle = max([t for t in self.schedule.values()]) + return math.ceil(max_cycle / self.initiation_interval) diff --git a/shark_turbine/kernel/wave/scheduling/resources.py b/shark_turbine/kernel/wave/scheduling/resources.py index 13e806874..e46bd5cf2 100644 --- a/shark_turbine/kernel/wave/scheduling/resources.py +++ b/shark_turbine/kernel/wave/scheduling/resources.py @@ -24,6 +24,9 @@ class Operation(Enum): READ_GLOBAL = "read_global" WRITE_GLOBAL = "write_global" MMA = "mma" + ALU = "alu" + VALU = "valu" + SALU = "salu" NOOP = "noop" diff --git a/shark_turbine/kernel/wave/scheduling/schedule.py b/shark_turbine/kernel/wave/scheduling/schedule.py index a03ad0823..8740fa083 100644 --- a/shark_turbine/kernel/wave/scheduling/schedule.py +++ b/shark_turbine/kernel/wave/scheduling/schedule.py @@ -11,8 +11,12 @@ from .graph_utils import create_scheduling_edges, Edge from .resources import get_available_resources, annotate_resource_usage from ..visualization import visualize_edges, visualize_graph, visualize_schedule -from ..utils import subs_idxc, graph_copy, erase_graph +from .loop_reconstruction import construct_pipelined_loop +from ..utils import graph_copy, erase_graph, get_tiling_constraint, subs_idxc import torch.fx as fx +from ....support.logging import get_logger + +logger = get_logger("turbine.wave.scheduling.schedule") def visualize_scheduling_graph(edges: list[Edge]): @@ -21,7 +25,7 @@ def visualize_scheduling_graph(edges: list[Edge]): def schedule_reduction( reduction: Reduction, trace: CapturedTrace, constraints: list[Constraint] -): +) -> dict[fx.Node, int]: """ Clones the reduction graph and does the following: 1. Annotates resource usage for each node. @@ -68,8 +72,35 @@ def schedule_reduction( erase_graph(graph) + # After scheduling has completed, we have enough information to decide + # whether to pipeline the loop. For pipelining to be possible, we need + # to have atleast N iterations of the loop where N > num_stages - 1 (because + # we will be peeling off num_stages iterations from the loop). + tiling_constraint = get_tiling_constraint(reduction, constraints) + max_induction_variable = int( + subs_idxc(tiling_constraint.dim) // subs_idxc(tiling_constraint.tile_size) + ) + if max_induction_variable <= scheduler.num_stages - 1: + logger.warn("Not enough iterations to pipeline the loop. Skipping pipelining.") + return {} + + new_reduction = construct_pipelined_loop( + trace, + reduction, + reduction_graph, + constraints, + scheduler, + node_map, + max_induction_variable, + visualize, + ) + + return {new_reduction: max_induction_variable - (scheduler.num_stages - 1)} + -def schedule_graph(trace: CapturedTrace, constraints: list[Constraint]): +def schedule_graph( + trace: CapturedTrace, constraints: list[Constraint] +) -> dict[fx.Node, int]: """ Given a graph, pipelines the reductions in the graph. 
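+    Returns a map from each pipelined reduction to the number of iterations
+    its kernel executes once the prologue and epilogue stages are peeled off.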
""" @@ -81,5 +112,9 @@ def is_reduction(node: fx.Node) -> bool: if not reduction_nodes: return + scheduling_metadata = {} for reduction_node in reduction_nodes: - schedule_reduction(get_custom(reduction_node), trace, constraints) + scheduling_metadata.update( + schedule_reduction(get_custom(reduction_node), trace, constraints) + ) + return scheduling_metadata diff --git a/shark_turbine/kernel/wave/utils.py b/shark_turbine/kernel/wave/utils.py index affd5fef5..69dbd3626 100644 --- a/shark_turbine/kernel/wave/utils.py +++ b/shark_turbine/kernel/wave/utils.py @@ -16,8 +16,8 @@ from .._support.tracing import CapturedTrace from .._support.indexing import IndexExpr, IndexingContext, IndexSymbol, IndexSequence from ..lang.global_symbols import * -from ..ops.wave_ops import get_custom, Output, Write, MMA -from .constraints import Constraint, HardwareConstraint, TilingConstraint +from ..ops.wave_ops import get_custom, Output, Write, Reduction, MMA, CustomOp +from .constraints import HardwareConstraint, TilingConstraint, Constraint import torch.fx as fx import shark_turbine.kernel.lang as tkl @@ -90,6 +90,21 @@ def print_trace(trace: CapturedTrace, custom_print: bool = True): print(get_custom(node)) +def print_subgraph(trace: CapturedTrace, subgraph_name: str, custom_print: bool = True): + """ + Prints a specific subgraphs of a trace. + The graphs are printed first in the torch printing format and + then using our custom node format. + """ + # The root graph is at the back so we print the subgraphs in reverse order + for name, subgraph in trace.region_graph.subgraphs.items(): + if name == subgraph_name: + print(subgraph) + if custom_print: + for node in subgraph.nodes: + print(get_custom(node)) + + def DCE(trace: CapturedTrace): """ Removes all operators that are not used in the graph, @@ -378,3 +393,66 @@ def erase_graph(graph: fx.Graph): for user in node.users: graph.erase_node(user) graph.erase_node(node) + + +def get_induction_variable( + reduction: Reduction, constraints: list[Constraint] +) -> IndexSymbol: + induction_var = None + for constraint in constraints: + if ( + isinstance(constraint, TilingConstraint) + and reduction.axis == constraint.dim + ): + induction_var = constraint.induction_var + break + else: + raise ValueError(f"Could not find induction variable for reduction {reduction}") + return induction_var + + +def get_tiling_constraint( + reduction: Reduction, constraints: list[Constraint] +) -> TilingConstraint: + for constraint in constraints: + if ( + isinstance(constraint, TilingConstraint) + and reduction.axis == constraint.dim + ): + return constraint + else: + raise ValueError(f"Could not find tiling constraint for reduction {reduction}") + + +def replace_uses_in(users: dict[fx.Node, list[CustomOp]], old: CustomOp, new: fx.Node): + """ + Replace all uses of `old` with `new` in the list of users. + """ + for user in users[old]: + for i, arg in enumerate(user.fx_node.args): + if arg == old.fx_node: + user.update_arg(i, new) + + +def get_scheduling_mask(operation: Operation) -> int: + """ + Returns the scheduling mask for the given operation. 
+ """ + match operation: + case Operation.READ_GLOBAL: + return int("0x20", 0) + case Operation.WRITE_GLOBAL: + return int("0x40", 0) + case Operation.READ_SHARED: + return int("0x100", 0) + case Operation.WRITE_SHARED: + return int("0x200", 0) + case Operation.MMA: + return int("0x8", 0) + case Operation.ALU: + return int("0x1", 0) + case Operation.VALU: + return int("0x2", 0) + case Operation.SALU: + return int("0x4", 0) + return None diff --git a/shark_turbine/kernel/wave/visualization.py b/shark_turbine/kernel/wave/visualization.py index 924c36bdf..d6438bfce 100644 --- a/shark_turbine/kernel/wave/visualization.py +++ b/shark_turbine/kernel/wave/visualization.py @@ -11,6 +11,8 @@ graphviz_disabled = True from torch import fx from .scheduling.graph_utils import Edge +from ..ops.wave_ops import Output, Placeholder, IterArg, get_custom +from collections import ChainMap import math @@ -27,6 +29,9 @@ def visualize_graph(graph: fx.Graph, file_name: str): G.add_node(node_numbering[id(node)], label=node.name) for node in graph.nodes: for user in node.users.keys(): + # Handle scenario where nodes are shared across graphs. + if user not in graph.nodes: + continue G.add_edge(node_numbering[id(node)], node_numbering[id(user)]) G.layout(prog="dot") G.draw(file_name) @@ -71,7 +76,7 @@ def visualize_schedule( for key, value in schedule.items(): table[value + stage * initiation_interval][stage] += f"{key}
" - df = pd.DataFrame(table, columns=[f"Stage {i}" for i in range(cols)]) + df = pd.DataFrame(table, columns=[f"Iteration {i}" for i in range(cols)]) s = df.style.set_properties(**{"text-align": "center"}) s = s.set_table_styles( [ @@ -95,3 +100,91 @@ def visualize_schedule( ).to_html() with open(f"{file_name}", "w") as f: f.write(output) + + +def visualize_mapped_graphs( + second: fx.Graph, + rotating_registers: dict[fx.Node, list[fx.Node]], + mappings: list[list[dict[fx.Node, fx.Node]]], + file_name: str, +): + """ + Given the pipelined graph and a list of mappings of nodes from the original + graph to the pipelined graph (per stage), visualize the pipelined graph (with their original labels) + + """ + + if graphviz_disabled: + raise ImportError("pygraphviz not installed, cannot visualize graph") + second_numbering = number_nodes(second) + + flat_inverse_map: dict[fx.Node, fx.Node] = {} + flat_map: dict[fx.Node, fx.Node] = {} + for iteration_mapping in mappings: + for mapping in iteration_mapping: + flat_inverse_map.update({v: k for k, v in mapping.items()}) + flat_map.update(mapping) + flat_inverse_map = ChainMap(flat_inverse_map) + flat_map = ChainMap(flat_map) + + # Draw nodes and edges in the pipelined graph. + G = pgv.AGraph(directed=True) + G0 = G.add_subgraph(name="pipelined") + stage: dict[fx.Node, int] = {} + for node in second.nodes: + if hasattr(node, "scheduling_parameters"): + if node in flat_inverse_map: + name = flat_inverse_map[node].name + else: + name = node.name + else: + name = node.name + G0.add_node( + second_numbering[id(node)], + label=name, + color="lightblue", + style="filled", + ) + for user in node.users.keys(): + if user not in second.nodes: + continue + if isinstance(get_custom(user), Output): + continue + G0.add_edge( + second_numbering[id(node)], + second_numbering[id(user)], + color="black", + ) + + # Draw nodes and edges in the original graph. + colors = ["red", "green", "orange", "purple", "orange", "cyan", "magenta"] + max_stage = len(mappings) + for node, mapped_node in flat_map.items(): + for user in node.users.keys(): + if user not in flat_map: + continue + mapped_user = flat_map[user] + if mapped_user not in second.nodes or mapped_node not in second.nodes: + continue + stage = "" + if hasattr(user, "scheduling_parameters"): + stage = user.scheduling_parameters["stage"] + G.add_edge( + second_numbering[id(mapped_node)], + second_numbering[id(mapped_user)], + label=f"{stage}", + color=colors[stage % max_stage], + ) + + # Draw edges between rotating registers for the same variable. + for node in rotating_registers: + all_registers = [k for k, v in flat_inverse_map.items() if v == node] + for second, first in zip(all_registers[:-1], all_registers[1:]): + G.add_edge( + second_numbering[id(first)], + second_numbering[id(second)], + color="blue", + ) + + G.layout(prog="dot") + G.draw(file_name) diff --git a/shark_turbine/kernel/wave/wave.py b/shark_turbine/kernel/wave/wave.py index eb6003de3..ef1dbd7ff 100644 --- a/shark_turbine/kernel/wave/wave.py +++ b/shark_turbine/kernel/wave/wave.py @@ -221,8 +221,9 @@ def _trace_and_get_kernel_signature( decompose_reduce_ops(graph, self.constraints, idxc.subs) # Schedule the reduction ops. + scheduling_metadata = {} if kwargs.get("schedule", False): - schedule_graph(graph, self.constraints) + scheduling_metadata = schedule_graph(graph, self.constraints) # Add shared memory barriers. 
add_shared_memory_barriers(graph) @@ -250,7 +251,9 @@ def _trace_and_get_kernel_signature( entrypoint_name, kernel_sig, grid, workgroup_size, subgroup_size ) - emitter = WaveEmitter(dispatch_entrypoint, graph, self.constraints) + emitter = WaveEmitter( + dispatch_entrypoint, graph, self.constraints, scheduling_metadata + ) emitter.emit(graph.get_root_graph()) emitter.finish() diff --git a/tests/kernel/wave/wave_gemm_test.py b/tests/kernel/wave/wave_gemm_test.py index 344032a4d..2386ebd96 100644 --- a/tests/kernel/wave/wave_gemm_test.py +++ b/tests/kernel/wave/wave_gemm_test.py @@ -15,6 +15,7 @@ from shark_turbine.kernel.wave.iree_utils import generate_iree_ref import os import json +from torch.testing import assert_close _run_e2e = int(os.environ.get("WAVE_RUN_E2E_TESTS", 0)) require_e2e = pytest.mark.skipif(not _run_e2e, reason="e2e tests are disabled") @@ -40,7 +41,8 @@ def get_test_shapes(test_name: str) -> list[tuple[int]]: @require_e2e @pytest.mark.parametrize("shape", get_test_shapes("test_gemm")) -def testGemm(shape: tuple[int]): +@pytest.mark.parametrize("enable_scheduling", [False, True]) +def testGemm(shape: tuple[int], enable_scheduling: bool): # Input sizes M = tkl.sym.M @@ -106,10 +108,22 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: M: shape[0], N: shape[1], K: shape[2], + READ_SHARED_DELAY: 1, + WRITE_SHARED_DELAY: 1, + READ_GLOBAL_DELAY: 2, + WRITE_GLOBAL_DELAY: 2, + MMA_DELAY: 1, + SHARED_MEMORY_UNITS: 4, + GLOBAL_MEMORY_UNITS: 4, + MMA_UNITS: 4, } config = {"backend": "rocm", "device": "hip", "target": "gfx942"} with tk.gen.TestLaunchContext( - hyperparams, canonicalize=True, run=True, run_config=config + hyperparams, + canonicalize=True, + run=True, + run_config=config, + schedule=enable_scheduling, ): a = torch.randn(shape[0], shape[2], dtype=torch.float16) b = torch.randn(shape[1], shape[2], dtype=torch.float16) @@ -123,9 +137,4 @@ def repeat(acc: tkl.Register[M, N, tkl.f32]) -> tkl.Register[M, N, tkl.f32]: iree_ref = torch.zeros(shape[0], shape[1], dtype=torch.float32) generate_iree_ref("mmt", [a, b], [iree_ref], config) - assert torch.equal(c, iree_ref) - - -if __name__ == "__main__": - logging.basicConfig(level=logging.DEBUG) - unittest.main() + assert_close(c, iree_ref)
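
For reference only (not part of the patch): a minimal standalone sketch, with the two schedule helpers copied from loop_reconstruction_utils.py above, showing which stages the prologue (fill) and epilogue (drain) execute for an assumed four-stage pipeline.

from typing import Optional

def create_fill_stage_schedule(n: int) -> list[list[Optional[int]]]:
    # Prologue: the i-th peeled iteration runs stages i, i-1, ..., 0.
    schedule = []
    for i in range(n - 1):
        row = list(range(i, -1, -1))
        row.extend([None] * (n - i - 1))
        schedule.append(row)
    return schedule

def create_drain_stage_schedule(n: int) -> list[list[Optional[int]]]:
    # Epilogue: the i-th peeled iteration finishes stages n-1, ..., i+1.
    schedule = []
    for i in range(n - 1):
        row = [None] * (i + 1)
        row.extend(range(n - 1, i, -1))
        schedule.append(row)
    return schedule

if __name__ == "__main__":
    for row in create_fill_stage_schedule(4):
        print(row)  # [0, None, None, None], [1, 0, None, None], [2, 1, 0, None]
    for row in create_drain_stage_schedule(4):
        print(row)  # [None, 3, 2, 1], [None, None, 3, 2], [None, None, None, 3]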