-
Notifications
You must be signed in to change notification settings - Fork 159
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[compute/cker] Introduce the ShapeIterator (#14311)
This commit introduces an utility that effectively makes the Shape objects iterable. It's an iterator class which points to the individual dimensions in the shape and allows the interoperability of the Shape class and STL algorithms as well as range-based for loops. The iterator fulfills the requirements of a bidirectional iterator. In addition this commit contains one extra utility which allows the Shape objects conversion to std::string by concatenating them with a comma. ONE-DCO-1.0-Signed-off-by: Tomasz Dolbniak <[email protected]> Co-authored-by: SeungHui Youn <[email protected]>
- Loading branch information
Showing
3 changed files
with
229 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#ifndef __NNFW_CKER_SHAPE_ITERATOR_H__ | ||
#define __NNFW_CKER_SHAPE_ITERATOR_H__ | ||
|
||
#include <utility> | ||
#include "cker/Shape.h" | ||
|
||
namespace nnfw | ||
{ | ||
namespace cker | ||
{ | ||
struct ShapeIterator | ||
{ | ||
/// Definition of this iterator's traits that can be accessed by std::iterator_traits<It> | ||
using value_type = decltype(std::declval<Shape>().Dims(0)); | ||
using difference_type = std::ptrdiff_t; | ||
using pointer = value_type *; | ||
using reference = value_type &; | ||
using iterator_category = std::bidirectional_iterator_tag; | ||
|
||
ShapeIterator(const Shape &s) : _shape{s}, _current{0}, _last{s.DimensionsCount()} {} | ||
static ShapeIterator end_iterator(const Shape &s) { return ShapeIterator(s, EndIteratorTag{}); } | ||
|
||
ShapeIterator &operator++() | ||
{ | ||
++_current; | ||
return *this; | ||
} | ||
|
||
// postincrement | ||
ShapeIterator operator++(int) | ||
{ | ||
auto copy = *this; | ||
++_current; | ||
return copy; | ||
} | ||
|
||
ShapeIterator &operator--() | ||
{ | ||
--_current; | ||
return *this; | ||
} | ||
|
||
ShapeIterator operator--(int) | ||
{ | ||
auto copy = *this; | ||
--_current; | ||
return copy; | ||
} | ||
|
||
bool operator!=(const ShapeIterator &other) const { return _current != other._current; } | ||
bool operator==(const ShapeIterator &other) const { return _current == other._current; } | ||
|
||
/// Because the underlying method returns by-value, this operator does the same | ||
/// instead of returning by-reference like most iterators do. | ||
value_type operator*() const { return _shape.Dims(_current); } | ||
|
||
private: | ||
struct EndIteratorTag | ||
{ | ||
}; | ||
// Creates an iterator instance pointing to the past-the-end element | ||
// This iterator doesn't point to a valid element and thus its dereference is undefined behavior | ||
ShapeIterator(const Shape &s, EndIteratorTag) | ||
: _shape{s}, _current{s.DimensionsCount()}, _last{s.DimensionsCount()} | ||
{ | ||
} | ||
|
||
const Shape &_shape; | ||
int32_t _current = 0, _last = 0; | ||
}; | ||
|
||
inline ShapeIterator begin(const Shape &s) { return ShapeIterator(s); } | ||
inline ShapeIterator end(const Shape &s) { return ShapeIterator::end_iterator(s); } | ||
|
||
} // namespace cker | ||
} // namespace nnfw | ||
|
||
#endif // |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,108 @@ | ||
/* | ||
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include <cker/ShapeIterator.h> | ||
#include <cker/Utils.h> | ||
#include <gtest/gtest.h> | ||
#include <numeric> | ||
|
||
using namespace nnfw::cker; | ||
|
||
TEST(CKer_Utils, ShapeIterator_basic) | ||
{ | ||
const Shape test_shape{1, 3, 1024, 768}; | ||
{ | ||
// test the front and back iterability with basic operators | ||
ShapeIterator it{test_shape}; | ||
EXPECT_EQ(*it, 1); | ||
++it; | ||
EXPECT_EQ(*it, 3); | ||
it++; | ||
EXPECT_EQ(*it, 1024); | ||
--it; | ||
EXPECT_EQ(*it, 3); | ||
it--; | ||
EXPECT_EQ(*it, 1); | ||
} | ||
{ | ||
// test the iterator's compatibility with STL iterator functions | ||
ShapeIterator it{test_shape}; | ||
auto it2 = std::next(it); | ||
EXPECT_EQ(*it2, 3); | ||
EXPECT_EQ(*it, 1); // make sure the original iterator is untouched | ||
|
||
std::advance(it2, 2); | ||
EXPECT_EQ(*it2, 768); | ||
|
||
std::advance(it2, -1); | ||
EXPECT_EQ(*it2, 1024); | ||
} | ||
{ | ||
// postincrement operator test | ||
ShapeIterator it{test_shape}; | ||
const auto it2 = it++; | ||
EXPECT_EQ(*it, 3); | ||
EXPECT_EQ(*it2, 1); | ||
} | ||
{ | ||
// test the ability to iterate over a Shape with range-based loops | ||
int expected_dims[] = {1, 3, 1024, 768}; | ||
int i = 0; | ||
for (auto &&dim : test_shape) | ||
{ | ||
EXPECT_EQ(dim, expected_dims[i++]); | ||
} | ||
} | ||
{ | ||
// test the ability to retrieve iterators using begin & end | ||
const auto first = begin(test_shape); | ||
const auto last = end(test_shape); | ||
EXPECT_GT(std::distance(first, last), 0); | ||
EXPECT_EQ(std::distance(first, last), test_shape.DimensionsCount()); | ||
} | ||
|
||
{ | ||
// test and demostrate the usage of iterators with STL algos | ||
const auto first = begin(test_shape); | ||
const auto last = end(test_shape); | ||
const auto shape_elems = | ||
std::accumulate(first, last, 1, std::multiplies<ShapeIterator::value_type>{}); | ||
EXPECT_EQ(shape_elems, test_shape.FlatSize()); | ||
} | ||
|
||
{ | ||
// Shape and ofstream interoperability test | ||
std::stringstream ss; | ||
ss << test_shape; | ||
EXPECT_EQ(ss.str(), "[1,3,1024,768]"); | ||
} | ||
} | ||
|
||
TEST(CKer_Utils, neg_ShapeIterator_empty_shape) | ||
{ | ||
const Shape test_shape{}; | ||
{ | ||
const auto first = begin(test_shape); | ||
const auto last = end(test_shape); | ||
EXPECT_EQ(first, last); | ||
} | ||
|
||
{ | ||
std::stringstream ss; | ||
ss << test_shape; | ||
EXPECT_EQ(ss.str(), "[]"); | ||
} | ||
} |