Skip to content

Commit

Permalink
Layout convolution as NHWC or NCHW only
Browse files Browse the repository at this point in the history
  • Loading branch information
pfultz2 committed Dec 20, 2024
1 parent 52e204d commit 5be9615
Showing 1 changed file with 7 additions and 3 deletions.
10 changes: 7 additions & 3 deletions src/layout_convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,14 +40,18 @@ inline namespace MIGRAPHX_INLINE_NS {
namespace {
std::vector<int64_t> get_permutation(instruction_ref ins, const layout_convolution& lc)
{
std::vector<int64_t> perm(ins->get_shape().ndim());
if(lc.channels_last)
{
std::vector<int64_t> perm(ins->get_shape().ndim());
std::iota(perm.begin() + 1, perm.end() - 1, 2);
perm.back() = 1;
return perm;
}
return find_permutation(ins->inputs().front()->get_shape());
else
{
std::iota(perm.begin(), perm.end(), 0);
}
return perm;

}

bool skip_layout(const shape& s)
Expand Down

0 comments on commit 5be9615

Please sign in to comment.