From c3b6f22845f07ddc3ddbfce37fee25ebce602cf9 Mon Sep 17 00:00:00 2001 From: Shiv Date: Sat, 11 Jan 2025 18:16:28 -0800 Subject: [PATCH] add experimental matcher --- src/targets/gpu/fuse_mlir.cpp | 46 +++++++++++++++++++++++++++++++++++ 1 file changed, 46 insertions(+) diff --git a/src/targets/gpu/fuse_mlir.cpp b/src/targets/gpu/fuse_mlir.cpp index 65a27a76ad1..b8abc72c0c4 100644 --- a/src/targets/gpu/fuse_mlir.cpp +++ b/src/targets/gpu/fuse_mlir.cpp @@ -911,6 +911,46 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op } }; +struct find_mlir_reshape_pointwise +{ + mlir_mode conv_mode = mlir_mode::none; + mlir_mode dot_mode = mlir_mode::none; + auto matcher() const + { + auto reshapes = reshaper_names(); + // Dont want to move ops before slice + reshapes.erase("slice"); + auto rsp_dot_or_conv = + match::name(reshapes)( + match::any_of[match::inputs()](is_mlir_dot(dot_mode), is_mlir_conv(conv_mode)) + .bind("mlir_op")) + .bind("rsp"); + return mlir_pointwise()(match::any_of[match::inputs()](rsp_dot_or_conv)); + } + + void apply(module_pass_manager& mpm, const match::matcher_result& r) const + { + auto ins = r.result; + auto rsp = r.instructions["rsp"]; + + if(not std::all_of(ins->inputs().begin(), ins->inputs().end(), [&](auto i) { + return i->get_operator() == rsp->get_operator(); + })) + return; + + std::vector new_inps; + std::transform(ins->inputs().begin(), + ins->inputs().end(), + std::back_inserter(new_inps), + [&](auto i) { return i->inputs().front(); }); + + auto new_pw = mpm.get_module().insert_instruction( + ins, ins->get_operator(), new_inps, ins->module_inputs()); + + mpm.get_module().replace_instruction(ins, rsp->get_operator(), {new_pw}); + } +}; + struct find_pointwise_mlir { auto supported_pointwise() const { return mlir_input_pointwise(match::used_once()); } @@ -1061,6 +1101,12 @@ void fuse_mlir::apply(module_pass_manager& mpm) const return std::max(m1, m2); }; + match::find_matches( + mpm, + find_mlir_reshape_pointwise{.conv_mode = get_mode("fused_convolution", mlir_mode::fast), + .dot_mode = get_mode("fused_dot", mlir_mode::fast)}); + mpm.run_pass(dead_code_elimination{}); + // Attention offloads; default disabled if(mlir_attention_enabled(ctx) or enable_extra) {