Tile channels for group norm and also fuse output reshapes in mlir #3750

Draft · wants to merge 2 commits into base: develop
67 changes: 67 additions & 0 deletions src/targets/gpu/fuse_mlir.cpp
@@ -1040,6 +1040,69 @@ struct find_unpack_int4_mlir_op
    }
};

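// Fuses a reshape-like instruction that consumes the result of a fused MLIR
// op into that op's module, so the reshape is folded into the generated
// kernel instead of running as a separate instruction.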
struct find_mlir_reshape_ops
{
    auto matcher() const
    {
        auto reshapes = reshaper_names();
        // slice is not supported
        reshapes.erase("slice");
        return match::name(reshapes)(
            match::arg(0)(match::name("gpu::mlir_op")(match::used_once())), match::used_once());
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins      = r.result;
        auto mlir_ins = ins->inputs().front();

        auto* mm      = mlir_ins->module_inputs().front();
        module_ref nm = mpm.create_module(mm->name() + ":" + ins->name());
        nm->set_bypass();

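        // Replay the body of the original fused MLIR module into the new
        // module, then append the matched reshape so it becomes the module's
        // return value and is computed inside the MLIR kernel.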
        auto y   = nm->fuse(*mm, mlir_ins->inputs());
        auto ret = nm->add_instruction(ins->get_operator(), y);
        nm->add_return({ret});
        mpm.get_module().replace_instruction(ins, mlir_ins->get_operator(), mlir_ins->inputs(), {nm});
    }
};

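// Rewrites a reshape that splits the channels of a convolution output into
// groups (the channel split used by group norm) so the tiled tensor is
// materialized with the per-group channels innermost in memory.
// For example (shapes are illustrative):
//   convolution -> [2, 16, 28, 28] (channels-last)
//   reshape     -> [2, 4, 4, 28, 28]
// becomes reshape -> transpose -> contiguous -> transpose, producing the
// same 5-D shape in a layout suited to the group norm that follows.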
struct find_convolution_reshape
{
    auto matcher() const
    {
        return match::name("reshape")(
            match::arg(0)(match::name("convolution").bind("convolution")));
    }

    void apply(module_pass_manager& mpm, const match::matcher_result& r) const
    {
        auto ins       = r.result;
        auto conv      = r.instructions["convolution"];
        auto out_dims  = ins->get_shape().lens();
        auto conv_dims = conv->get_shape().lens();
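        // Only rewrite a 4-D convolution output reshaped to 5-D
        // [N, G, C/G, H, W]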
        if(out_dims.size() != 5)
            return;
        if(conv_dims.size() != 4)
            return;
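        // The convolution output must be in a channels-last layout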
        auto perm = find_permutation(conv->get_shape());
        if(perm.back() != 1)
            return;
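        // The batch dim must be unchanged and the spatial dims must carry over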
        if(out_dims[0] != conv_dims[0])
            return;
        if(not std::equal(
               conv_dims.begin() + 2, conv_dims.end(), out_dims.begin() + 3, out_dims.end()))
            return;
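        // Tile only when the groups are small: at most 32 channels per group
        // and at least 4 groups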
        if(out_dims[2] > 32)
            return;
        if(out_dims[1] < 4)
            return;
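        // Re-emit the reshape, then force the per-group channels innermost in
        // memory with a transpose+contiguous+transpose sandwich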
        auto reshape = mpm.get_module().insert_instruction(ins, ins->get_operator(), ins->inputs());
        // Note: the same layout change could be expressed with a single
        // make_op("layout", {{"permutation", {0, 1, 3, 4, 2}}}) instruction
        auto t1 = mpm.get_module().insert_instruction(
            ins, make_op("transpose", {{"permutation", {0, 1, 3, 4, 2}}}), reshape);
        auto c  = mpm.get_module().insert_instruction(ins, make_op("contiguous"), t1);
        auto t2 = mpm.get_module().insert_instruction(
            ins, make_op("transpose", {{"permutation", {0, 1, 4, 2, 3}}}), c);
        mpm.get_module().replace_instruction(ins, t2);
    }
};

} // namespace

#endif // MIGRAPHX_MLIR
@@ -1061,6 +1124,7 @@ void fuse_mlir::apply(module_pass_manager& mpm) const
        return std::max(m1, m2);
    };

    match::find_matches(mpm, find_convolution_reshape{});
    // Attention offloads; default disabled
    if(mlir_attention_enabled(ctx) or enable_extra)
    {
@@ -1092,6 +1156,9 @@ void fuse_mlir::apply(module_pass_manager& mpm) const

    match::find_matches(mpm, find_pointwise_mlir{});
    match::find_matches(mpm, find_unpack_int4_mlir_op{});

    // Repeat so that chains of up to four consecutive reshapes are folded
    // into the preceding mlir_op one at a time
    for(int i = 0; i < 4; i++)
        match::find_matches(mpm, find_mlir_reshape_ops{});

#else
    (void)mpm;