Skip to content

Commit

Permalink
[luci/pass] Introduce ArrayIndex helper (#14565)
Browse files Browse the repository at this point in the history
Let's introduce Array4DIndex to help calculating 4D array index.

ONE-DCO-Signed-off-by: Dayoung Lee <[email protected]>
Co-authored-by: Hyukjin Jeong <[email protected]>
Co-authored-by: SaeHie Park <[email protected]>
  • Loading branch information
3 people authored Jan 21, 2025
1 parent f9d691b commit c305c84
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 0 deletions.
63 changes: 63 additions & 0 deletions compiler/luci/pass/src/helpers/ArrayIndex.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ArrayIndex.h"

#include <cassert>
#include <stdexcept>

namespace luci
{

#define THROW_UNLESS(COND) \
if (not(COND)) \
throw std::invalid_argument("");

Array4DIndex::Array4DIndex(uint32_t D0, uint32_t D1, uint32_t D2, uint32_t D3)
: _dim{D0, D1, D2, D3}
{
_strides[3] = 1;
_strides[2] = D3;
_strides[1] = D3 * D2;
_strides[0] = D3 * D2 * D1;

for (int i = 0; i < 4; ++i)
{
THROW_UNLESS(_strides[i] > 0);
}
}

uint32_t Array4DIndex::operator()(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) const
{
THROW_UNLESS(i0 < _dim[0] && i1 < _dim[1] && i2 < _dim[2] && i3 < _dim[3]);

return i0 * _strides[0] + i1 * _strides[1] + i2 * _strides[2] + i3 * _strides[3];
}

uint32_t Array4DIndex::size(void) const
{

for (int i = 0; i < 4; ++i)
{
THROW_UNLESS(_dim[i] > 0);
}

return _dim[0] * _dim[1] * _dim[2] * _dim[3];
}

uint32_t Array4DIndex::stride(uint32_t axis) const { return _strides[axis]; }

} // namespace luci
47 changes: 47 additions & 0 deletions compiler/luci/pass/src/helpers/ArrayIndex.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __LUCI_PASS_HELPERS_ARRAY_INDEX_H__
#define __LUCI_PASS_HELPERS_ARRAY_INDEX_H__

#include <cstdint>

namespace luci
{

/// @brief Index class for 4D tensor to calculate linear index from multi-dimensional indices.
class Array4DIndex final
{
public:
Array4DIndex(uint32_t d0, uint32_t d1, uint32_t d2, uint32_t d3);

/// @brief Calculate linear index from multi-dimensional indices.
uint32_t operator()(uint32_t i0, uint32_t i1, uint32_t i2, uint32_t i3) const;

/// @brief Get total number of elements in the tensor.
uint32_t size(void) const;

/// @brief Get stride of the given axis.
uint32_t stride(uint32_t axis) const;

private:
uint32_t _dim[4];
uint32_t _strides[4];
};

} // namespace luci

#endif // __LUCI_PASS_HELPERS_ARRAY_INDEX_H__
49 changes: 49 additions & 0 deletions compiler/luci/pass/src/helpers/ArrayIndex.test.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "ArrayIndex.h"

#include <gtest/gtest.h>

TEST(LuciPassHelpersArrayIndex, array_index_4d)
{
luci::Array4DIndex idx(5, 4, 3, 2);

EXPECT_EQ(idx(0, 0, 0, 0), 0);

// stride
EXPECT_EQ(idx(1, 0, 0, 0), idx.stride(0));
EXPECT_EQ(idx(0, 1, 0, 0), idx.stride(1));
EXPECT_EQ(idx(0, 0, 1, 0), idx.stride(2));
EXPECT_EQ(idx(0, 0, 0, 1), idx.stride(3));

// size
EXPECT_EQ(idx.size(), 5 * 4 * 3 * 2);

EXPECT_EQ(idx(4, 3, 2, 1), 4 * 4 * 3 * 2 + 3 * 3 * 2 + 2 * 2 + 1);
}

TEST(LuciPassHelpersArrayIndex, array_invalid_index_4d_NEG)
{
luci::Array4DIndex idx(4, 4, 3, 2);

EXPECT_ANY_THROW(idx(5, 0, 0, 0));
}

TEST(LuciPassHelpersArrayIndex, array_invalid_dim_4d_NEG)
{
EXPECT_ANY_THROW(luci::Array4DIndex idx(4, 0, 3, 2));
}

0 comments on commit c305c84

Please sign in to comment.