Skip to content

Commit

Permalink
Merge branch 'develop' into rewrite-dot
Browse files Browse the repository at this point in the history
  • Loading branch information
causten authored Dec 19, 2024
2 parents 4b59c2d + a678b48 commit 9600e10
Show file tree
Hide file tree
Showing 9 changed files with 750 additions and 124 deletions.
101 changes: 60 additions & 41 deletions src/include/migraphx/rewrite_reshapes.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
#include <migraphx/simplify_reshapes.hpp>
#include <migraphx/eliminate_common_subexpression.hpp>
#include <migraphx/dead_code_elimination.hpp>
#include <migraphx/shape_transform_descriptor.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
Expand Down Expand Up @@ -72,18 +73,19 @@ struct rewrite_reshapes

auto matcher() const
{
auto reshape =
match::name("reshape", "squeeze", "unsqueeze", "flatten")(match::used_once());
auto skip_contiguous_broadcast =
match::skip(match::name("contiguous", "multibroadcast")(match::used_once()));
auto skip_contiguous_broadcast_arg = [&](auto... ms) {
return match::arg(0)(skip_contiguous_broadcast(ms...));
};
auto reshapes = match::name("reshape",
"squeeze",
"unsqueeze",
"flatten",
"transpose",
"contiguous",
"multibroadcast",
"broadcast")(match::used_once());
auto pointwise = match::name(op1)(match::used_once());
auto reshape_pointwise =
reshape(skip_contiguous_broadcast_arg(pointwise.bind("x"))).bind("reshape");
return match::name(op2)(match::any_of[match::inputs()](
skip_contiguous_broadcast(reshape_pointwise).bind("input")));
auto reshapes_pointwise =
reshapes(match::arg(0)(match::skip(reshapes())(pointwise.bind("x"))));
return match::name(op2)(
match::any_of[match::inputs()](reshapes_pointwise.bind("input")));
}

template <class F>
Expand All @@ -100,6 +102,12 @@ struct rewrite_reshapes
return last;
}

template <class F>
static bool any_input_of(instruction_ref start, instruction_ref last, F f)
{
return find_input_if(start, last, f) != last;
}

static bool match_input(instruction_ref ins, instruction_ref x_ins)
{
if(ins->inputs().empty())
Expand All @@ -120,65 +128,76 @@ struct rewrite_reshapes
return result;
}

static bool is_broadcast(instruction_ref ins) { return ins->name() == "multibroadcast"; }

void apply(module_pass_manager& mpm, const match::matcher_result& r) const
{
auto ins = r.result;
auto x_ins = r.instructions["x"];
auto reshape_ins = r.instructions["reshape"];
auto input_ins = r.instructions["input"];

const auto has_broadcast_before_reshape = is_broadcasted(reshape_ins, x_ins);
const auto has_broadcast_after_reshape = is_broadcasted(input_ins, reshape_ins);
if(not has_broadcast_before_reshape.has_value())
return;
if(not has_broadcast_after_reshape.has_value())
// If its just a broadcast then skip
if(not any_input_of(input_ins, x_ins, [](instruction_ref x) {
return not contains({"multibroadcast", "broadcast", "contiguous"}, x->name());
}))
return;
if(*has_broadcast_after_reshape and *has_broadcast_before_reshape)
return;
const bool has_broadcast =
*has_broadcast_after_reshape or *has_broadcast_before_reshape;

auto dims1 = T::base_dims(ins);
auto dims2 = T::base_dims(x_ins);

if(elements(dims1) != elements(dims2))
return;

auto cd = common_dims::compute(T::base_dims(ins), T::base_dims(x_ins));
if(cd.dims.empty())
return;
std::vector<operation> ops;
auto next_ins = input_ins;
while(next_ins != x_ins)
{
ops.push_back(next_ins->get_operator());
next_ins = next_ins->inputs().front();
}
assert(next_ins == x_ins);
std::reverse(ops.begin(), ops.end());

if(ins->name() != "pointwise" and not T::supports(ins, cd.dims, cd.axes_map1))
auto desc =
shape_transform_descriptor::create(x_ins->get_shape().lens(), ops).rebase(dims2);
if(desc.empty())
return;
if(x_ins->name() != "pointwise" and not T::supports(x_ins, cd.dims, cd.axes_map2))
return;

auto reshape_input = [&](const auto& ins_to_insert) {
return [&](auto input) {
auto dims = cd.get_dimensions_for(input->get_shape().lens());
return mpm.get_module().insert_instruction(
ins_to_insert, make_op("reshape", {{"dims", dims}}), input);
auto cdims = desc.common_dims();
auto reshape_input = [&](const auto& ins_to_insert, auto generate) {
return [&, generate](auto input) {
auto gops = std::invoke(generate, desc, input->get_shape().lens());
auto start = input;
for(const auto& op : gops)
{
start = mpm.get_module().insert_instruction(ins_to_insert, op, start);
}
return start;
};
};
auto x_inputs = x_ins->inputs();
std::transform(
x_inputs.begin(), x_inputs.end(), x_inputs.begin(), reshape_input(x_ins));
auto new_x_ins = insert(mpm, x_ins, x_inputs, cd.axes_map2);
if(has_broadcast)
x_inputs.begin(),
x_inputs.end(),
x_inputs.begin(),
reshape_input(x_ins, &shape_transform_descriptor::generate_common_from_src));
auto new_x_ins = insert(mpm, x_ins, x_inputs, desc.common_axes_map_from_src());
if(new_x_ins->get_shape().lens() != cdims)
{
new_x_ins = mpm.get_module().insert_instruction(
x_ins, make_op("multibroadcast", {{"out_lens", cd.dims}}), new_x_ins);
x_ins, make_op("multibroadcast", {{"out_lens", cdims}}), new_x_ins);
}

auto inputs = ins->inputs();
std::transform(inputs.begin(), inputs.end(), inputs.begin(), [&](auto input) {
if(input == input_ins)
return new_x_ins;
return reshape_input(ins)(input);
return reshape_input(ins,
&shape_transform_descriptor::generate_common_from_dst)(input);
});
auto pw = insert(mpm, ins, inputs, cd.axes_map1);
mpm.get_module().replace_instruction(
ins, make_op("reshape", {{"dims", ins->get_shape().lens()}}), pw);
auto pw = insert(mpm, ins, inputs, desc.common_axes_map_from_dst());
auto rins =
reshape_input(ins, &shape_transform_descriptor::generate_dst_from_common)(pw);
mpm.get_module().replace_instruction(ins, rins);
}

static bool same_dims(instruction_ref ins)
Expand Down
24 changes: 24 additions & 0 deletions src/include/migraphx/shape_transform_descriptor.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,11 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor
shape_transform_descriptor() = default;
explicit shape_transform_descriptor(const std::vector<std::size_t>& dims);

static shape_transform_descriptor create(const std::vector<std::size_t>& dims,
const std::vector<operation>& ops);

shape_transform_descriptor rebase(const std::vector<std::size_t>& dims) const;

bool apply(const std::vector<operation>& ops);
bool apply_reshape(const std::vector<std::size_t>& rdims);
bool apply_reshape_impl(const std::vector<std::size_t>& rdims);
Expand All @@ -84,6 +89,22 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor
std::size_t elements() const;
std::vector<operation> generate() const;

bool has_broadcast() const;
void flatten_broadcast();

std::vector<std::size_t> common_dims(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<operation>
generate_common_from_src(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<operation>
generate_common_from_dst(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<operation>
generate_dst_from_common(const std::vector<std::size_t>& input_dims = {}) const;
std::vector<std::vector<std::size_t>> common_axes_map_from_src() const;
std::vector<std::vector<std::size_t>> common_axes_map_from_dst() const;

bool empty() const;
std::vector<std::size_t> lens() const;

struct MIGRAPHX_EXPORT dimension
{
void simplify();
Expand All @@ -105,6 +126,9 @@ struct MIGRAPHX_EXPORT shape_transform_descriptor

void add_split_axis(std::size_t i);

void expose();
void hide();

MIGRAPHX_EXPORT friend bool operator==(const sub& x, const sub& y);
MIGRAPHX_EXPORT friend bool operator!=(const sub& x, const sub& y);
MIGRAPHX_EXPORT friend std::ostream& operator<<(std::ostream& os, const sub& x);
Expand Down
Loading

0 comments on commit 9600e10

Please sign in to comment.