Skip to content

Commit

Permalink
[compute/cker] Introduce the ShapeIterator
Browse files Browse the repository at this point in the history
This commit introduces an utility that effectively makes the Shape objects iterable. It's an iterator class which points to the individual dimensions in the shape and allows the interoperability of the Shape class and STL algorithms as well as range-based for loops. The iterator fulfills the requirements of a bidirectional iterator.

In addition this commit contains one extra utility which allows the Shape objects conversion to std::string by contatenating them with a comma.

ONE-DCO-1.0-Signed-off-by: Tomasz Dolbniak <[email protected]>
  • Loading branch information
tomdol committed Nov 15, 2024
1 parent 1e09707 commit 52d6f6b
Show file tree
Hide file tree
Showing 3 changed files with 227 additions and 0 deletions.
95 changes: 95 additions & 0 deletions compute/cker/include/cker/ShapeIterator.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef __NNFW_CKER_SHAPE_ITERATOR_H__
#define __NNFW_CKER_SHAPE_ITERATOR_H__

#include <utility>
#include <iostream>
#include "cker/Shape.h"

namespace nnfw
{
namespace cker
{
struct ShapeIterator
{
/// Definition of this iterator's traits that can be accessed by std::iterator_traits<It>
using value_type = decltype(std::declval<Shape>().Dims(0));
using difference_type = std::ptrdiff_t;
using pointer = value_type *;
using reference = value_type &;
using iterator_category = std::bidirectional_iterator_tag;

ShapeIterator(const Shape &s) : _shape{s}, _current{0}, _last{s.DimensionsCount()} {}
static ShapeIterator end_iterator(const Shape &s) { return ShapeIterator(s, EndIteratorTag{}); }

ShapeIterator &operator++()
{
++_current;
return *this;
}

// postincrement
ShapeIterator operator++(int)
{
auto copy = *this;
++_current;
return copy;
}

ShapeIterator &operator--()
{
--_current;
return *this;
}

ShapeIterator operator--(int)
{
auto copy = *this;
--_current;
return copy;
}

bool operator!=(const ShapeIterator &other) const { return _current != other._current; }
bool operator==(const ShapeIterator &other) const { return _current == other._current; }

/// Because the underlying method returns by-value, this operator does the same
/// instead of returning by-reference like most iterators do.
value_type operator*() const { return _shape.Dims(_current); }

private:
struct EndIteratorTag
{
};
// Creates an iterator instance pointing to the past-the-end element
// This iterator doesn't point to a valid element and thus its dereference is undefined behavior
ShapeIterator(const Shape &s, EndIteratorTag)
: _shape{s}, _current{s.DimensionsCount()}, _last{s.DimensionsCount()}
{
}

const Shape &_shape;
int32_t _current = 0, _last = 0;
};

inline ShapeIterator begin(const Shape &s) { return ShapeIterator(s); }
inline ShapeIterator end(const Shape &s) { return ShapeIterator::end_iterator(s); }

} // namespace cker
} // namespace nnfw

#endif //
24 changes: 24 additions & 0 deletions compute/cker/include/cker/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,14 @@
#define __NNFW_CKER_UTILS_H__

#include "Shape.h"
#include "ShapeIterator.h"

#include "neon/neon_check.h"

#include <algorithm>
#include <cstdint>
#include <numeric>
#include <string>
#include <fixedpoint/fixedpoint.h>

namespace nnfw
Expand Down Expand Up @@ -480,6 +483,27 @@ template <typename T> class SequentialTensorWriter
T *output_ptr_;
};

inline std::ostream &operator<<(std::ostream &os, const Shape &shape)
{
using std::begin;
using std::end;

std::string formatted =
std::accumulate(begin(shape), end(shape), std::string{"["},
[](std::string joined, ShapeIterator::value_type dim) {
return std::move(joined).append(std::to_string(dim)).append(",");
});

if (formatted.back() == '[') {
formatted.push_back(']');
} else {
formatted.back() = ']';
}

os << formatted;
return os;
}

} // namespace cker
} // namespace nnfw

Expand Down
108 changes: 108 additions & 0 deletions compute/cker/src/ShapeIterator.test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cker/ShapeIterator.h>
#include <cker/Utils.h>
#include <gtest/gtest.h>
#include <numeric>

using namespace nnfw::cker;

TEST(CKer_Utils, ShapeIterator_basic)
{
const Shape test_shape{1, 3, 1024, 768};
{
// test the front and back iterability with basic operators
ShapeIterator it{test_shape};
EXPECT_EQ(*it, 1);
++it;
EXPECT_EQ(*it, 3);
it++;
EXPECT_EQ(*it, 1024);
--it;
EXPECT_EQ(*it, 3);
it--;
EXPECT_EQ(*it, 1);
}
{
// test the iterator's compatibility with STL iterator functions
ShapeIterator it{test_shape};
auto it2 = std::next(it);
EXPECT_EQ(*it2, 3);
EXPECT_EQ(*it, 1); // make sure the original iterator is untouched

std::advance(it2, 2);
EXPECT_EQ(*it2, 768);

std::advance(it2, -1);
EXPECT_EQ(*it2, 1024);
}
{
// postincrement operator test
ShapeIterator it{test_shape};
const auto it2 = it++;
EXPECT_EQ(*it, 3);
EXPECT_EQ(*it2, 1);
}
{
// test the ability to iterate over a Shape with range-based loops
int expected_dims[] = {1, 3, 1024, 768};
int i = 0;
for (auto &&dim : test_shape)
{
EXPECT_EQ(dim, expected_dims[i++]);
}
}
{
// test the ability to retrieve iterators using begin & end
const auto first = begin(test_shape);
const auto last = end(test_shape);
EXPECT_GT(std::distance(first, last), 0);
EXPECT_EQ(std::distance(first, last), test_shape.DimensionsCount());
}

{
// test and demostrate the usage of iterators with STL algos
const auto first = begin(test_shape);
const auto last = end(test_shape);
const auto shape_elems = std::accumulate(first, last, 1, std::multiplies<ShapeIterator::value_type>{});
EXPECT_EQ(shape_elems, test_shape.FlatSize());
}

{
// Shape and ofstream interoperability test
std::stringstream ss;
ss << test_shape;
EXPECT_EQ(ss.str(), "[1,3,1024,768]");
}
}

TEST(CKer_Utils, neg_ShapeIterator_empty_shape)
{
const Shape test_shape{};
{
const auto first = begin(test_shape);
const auto last = end(test_shape);
EXPECT_EQ(first, last);
}

{
std::stringstream ss;
ss << test_shape;
EXPECT_EQ(ss.str(), "[]");
}
}

0 comments on commit 52d6f6b

Please sign in to comment.