[cker] Make Mul support bool type. (Samsung#12145)
This commit makes the Mul kernel support the boolean type as well.
  - Add mul op unit tests
  - Support bool type

ONE-DCO-1.0-Signed-off-by: ragmani <[email protected]>
ragmani authored Dec 6, 2023
1 parent 4fb2a38 commit c338b1e
Showing 3 changed files with 288 additions and 2 deletions.
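For quick reference, a minimal, self-contained sketch of the new bool path (it simply mirrors the Bool8 unit test added in this commit): for bool operands, MUL is evaluated as element-wise logical AND.

#include <cker/operation/BinaryArithmeticOps.h>
#include <cassert>

int main()
{
  // Shape: {1, 2, 2, 1} for both operands and the output
  bool input1[4] = {true, true, false, false};
  bool input2[4] = {true, false, true, false};
  bool output[4];

  nnfw::cker::BinaryArithmeticOpParam param;
  nnfw::cker::Shape shape{1, 2, 2, 1};

  // bool MUL == logical AND: {T,T,F,F} * {T,F,T,F} -> {T,F,F,F}
  nnfw::cker::BinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL, bool>(
      param, shape, input1, shape, input2, shape, output);

  assert(output[0] && !output[1] && !output[2] && !output[3]);
  return 0;
}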
38 changes: 36 additions & 2 deletions compute/cker/include/cker/operation/BinaryArithmeticOps.h
@@ -19,6 +19,7 @@
#define __NNFW_CKER_BINARY_ARITHMETIC_OPS_H__

#include <functional>
#include <stdexcept>
#include "cker/operation/optimized/BinaryArithmeticOps.h"
#include "cker/operation/reference/BinaryArithmeticOps.h"
#include "cker/Shape.h"
@@ -32,7 +33,8 @@ namespace cker

namespace
{
template <BinaryArithmeticOpType op_type, typename T>
template <BinaryArithmeticOpType op_type, typename T,
typename std::enable_if_t<!std::is_same<T, bool>::value, bool> = true>
const std::function<T(const T &, const T &)> GetBinaryArtithmeticFn()
{
switch (op_type)
@@ -71,6 +73,23 @@ const std::function<T(const T &, const T &)> GetBinaryArtithmeticFn()
}
}
}

template <BinaryArithmeticOpType op_type, typename T,
typename std::enable_if_t<std::is_same<T, bool>::value, bool> = true>
const std::function<T(const bool &, const bool &)> GetBinaryArtithmeticFn()
{
switch (op_type)
{
case BinaryArithmeticOpType::MUL:
{
return [](const bool &a, const bool &b) -> bool { return a && b; };
}
default:
{
throw std::runtime_error("GetBinaryArtithmeticFn: Unsupported OpType with Bool8");
}
}
}
} // namespace

// Consolidates dimensions in broadcast inputs, checks for five-fold pattern.
@@ -190,7 +209,7 @@ inline bool ProcessBroadcastShapes(const Shape &shape0, const Shape &shape1,
}

template <BinaryArithmeticOpType op_type, typename T>
inline typename std::enable_if_t<!is_quant8<T>::value>
inline typename std::enable_if_t<!is_quant8<T>::value && !std::is_same<T, bool>::value>
BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
const T *input1_data, const Shape &input2_shape, const T *input2_data,
const Shape &output_shape, T *output_data)
@@ -199,6 +218,16 @@ BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_sh
output_shape, output_data, GetBinaryArtithmeticFn<op_type, T>());
}

template <BinaryArithmeticOpType op_type, typename T>
inline typename std::enable_if_t<!is_quant8<T>::value && std::is_same<T, bool>::value>
BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
const bool *input1_data, const Shape &input2_shape, const bool *input2_data,
const Shape &output_shape, bool *output_data)
{
reference::BinaryArithmeticOp(params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data, GetBinaryArtithmeticFn<op_type, bool>());
}

template <BinaryArithmeticOpType op_type, typename T>
inline typename std::enable_if_t<is_quant8<T>::value>
BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shape &input1_shape,
@@ -298,6 +327,11 @@ inline void BroadcastBinaryArithmeticOp(BinaryArithmeticOpParam &params, const S
const float *input2_data, const Shape &output_shape,
float *output_data)
{
if (output_shape.DimensionsCount() > 4)
throw std::runtime_error(
std::string("cker::BroadcastBinaryArithmeticOp: Unsupported rank size : ") +
std::to_string(output_shape.DimensionsCount()));

// Supported type is only float now
switch (op_type)
{
42 changes: 42 additions & 0 deletions compute/cker/include/cker/operation/reference/BinaryArithmeticOps.h
@@ -61,6 +61,20 @@ inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &params, const Shap
}
}

template <>
inline void BinaryArithmeticOp(const BinaryArithmeticOpParam &, const Shape &input1_shape,
const bool *input1_data, const Shape &input2_shape,
const bool *input2_data, const Shape &output_shape,
bool *output_data,
const std::function<bool(const bool &, const bool &)> &fn)
{
const int size = MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < size; i++)
{
output_data[i] = fn(input1_data[i], input2_data[i]);
}
}

template <typename T>
inline typename std::enable_if_t<is_quant8<T>::value> BroadcastBinaryArithmeticOpSlow(
const BinaryArithmeticOpParam &params, const Shape &input1_shape, const T *input1_data,
@@ -174,6 +188,34 @@ inline void BroadcastBinaryArithmeticOpSlow(
}
}

template <>
inline void BroadcastBinaryArithmeticOpSlow(
const BinaryArithmeticOpParam &, const Shape &input1_shape, const bool *input1_data,
const Shape &input2_shape, const bool *input2_data, const Shape &output_shape, bool *output_data,
const std::function<bool(const bool &, const bool &)> &fn)
{
NdArrayDesc<4> desc1;
NdArrayDesc<4> desc2;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1, &desc2);
const Shape extended_output_shape = Shape::ExtendedShape(4, output_shape);

for (int b = 0; b < extended_output_shape.Dims(0); ++b)
{
for (int y = 0; y < extended_output_shape.Dims(1); ++y)
{
for (int x = 0; x < extended_output_shape.Dims(2); ++x)
{
for (int c = 0; c < extended_output_shape.Dims(3); ++c)
{
output_data[Offset(extended_output_shape, b, y, x, c)] =
fn(input1_data[SubscriptToIndex(desc1, b, y, x, c)],
input2_data[SubscriptToIndex(desc2, b, y, x, c)]);
}
}
}
}
}

} // namespace reference
} // namespace cker
} // namespace nnfw
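For the broadcast case, a minimal sketch mirroring the Broadcast Bool8 unit test below: a single bool value (shape {1}) is broadcast against a {1, 2, 2, 1} tensor.

#include <cker/operation/BinaryArithmeticOps.h>
#include <cassert>

int main()
{
  bool input1[4] = {true, true, false, false}; // shape {1, 2, 2, 1}
  bool input2[1] = {true};                     // shape {1}, broadcast over input1
  bool output[4];

  nnfw::cker::BinaryArithmeticOpParam param;

  nnfw::cker::BroadcastBinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL, bool>(
      param, nnfw::cker::Shape{1, 2, 2, 1}, input1, nnfw::cker::Shape{1}, input2,
      nnfw::cker::Shape{1, 2, 2, 1}, output);

  // Multiplying by a broadcast true leaves each element unchanged: {T, T, F, F}
  assert(output[0] && output[1] && !output[2] && !output[3]);
  return 0;
}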
210 changes: 210 additions & 0 deletions compute/cker/src/Mul.test.cc
@@ -0,0 +1,210 @@
/*
* Copyright (c) 2023 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include <cker/operation/BinaryArithmeticOps.h>

#include <gtest/gtest.h>
#include <vector>

TEST(CKer_Operation, Mul)
{
// Simple
{
// Shape: {1, 2, 2, 1}
std::vector<int32_t> input1 = {10, 9, 11, 3};
// Shape: {1, 2, 2, 1}
std::vector<int32_t> input2 = {2, 2, 3, 4};
std::vector<int32_t> expected_output = {20, 18, 33, 12};
std::vector<int32_t> output(4);

nnfw::cker::BinaryArithmeticOpParam param;
param.quantized_activation_min = std::numeric_limits<int32_t>::lowest();
param.quantized_activation_max = std::numeric_limits<int32_t>::max();
nnfw::cker::Shape shape{1, 2, 2, 1};

nnfw::cker::BinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL>(
param, shape, input1.data(), shape, input2.data(), shape, output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_EQ(output[i], expected_output[i]);
}

// Negative Value
{
// Shape: {1, 2, 2, 1}
std::vector<int32_t> input1 = {10, -9, -11, 7};
// Shape: {1, 2, 2, 1}
std::vector<int32_t> input2 = {2, 2, -3, -4};
std::vector<int32_t> expected_output = {20, -18, 33, -28};
std::vector<int32_t> output(4);

nnfw::cker::BinaryArithmeticOpParam param;
param.quantized_activation_min = std::numeric_limits<int32_t>::lowest();
param.quantized_activation_max = std::numeric_limits<int32_t>::max();
nnfw::cker::Shape shape{1, 2, 2, 1};

nnfw::cker::BinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL>(
param, shape, input1.data(), shape, input2.data(), shape, output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_EQ(output[i], expected_output[i]);
}

// Broadcast
{
// Shape: {1, 2, 2, 1}
std::vector<int32_t> input1 = {10, -9, -11, 7};
// Shape: {1}
std::vector<int32_t> input2 = {-3};
std::vector<int32_t> expected_output = {-30, 27, 33, -21};
std::vector<int32_t> output(4);

nnfw::cker::BinaryArithmeticOpParam param;
param.broadcast_category = nnfw::cker::BroadcastableOpCategory::kGenericBroadcast;
param.quantized_activation_min = std::numeric_limits<int32_t>::lowest();
param.quantized_activation_max = std::numeric_limits<int32_t>::max();

nnfw::cker::BroadcastBinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL>(
param, nnfw::cker::Shape{1, 2, 2, 1}, input1.data(), nnfw::cker::Shape{1}, input2.data(),
nnfw::cker::Shape{1, 2, 2, 1}, output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_EQ(output[i], expected_output[i]);
}

// Simple Float
{
// Shape: {1, 2, 2, 1}
std::vector<float> input1 = {10, 9, -11.1, 3};
// Shape: {1, 2, 2, 1}
std::vector<float> input2 = {2, -2.2, -3.3, 4};
std::vector<float> expected_output = {20, -19.8, 36.63, 12};
std::vector<float> output(4);

nnfw::cker::BinaryArithmeticOpParam param;
param.float_activation_min = std::numeric_limits<float>::lowest();
param.float_activation_max = std::numeric_limits<float>::max();
nnfw::cker::Shape shape{1, 2, 2, 1};

nnfw::cker::BinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL>(
param, shape, input1.data(), shape, input2.data(), shape, output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
}

// Float Relu
{
// Shape: {1, 2, 2, 1}
std::vector<float> input1 = {10, 9, -11.1, 3};
// Shape: {1, 2, 2, 1}
std::vector<float> input2 = {2, -2.2, -3.3, 4};
std::vector<float> expected_output = {20, 0, 36.63, 12};
std::vector<float> output(4);

nnfw::cker::BinaryArithmeticOpParam param;
param.float_activation_min = 0;
param.float_activation_max = std::numeric_limits<float>::max();
nnfw::cker::Shape shape{1, 2, 2, 1};

nnfw::cker::BinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL>(
param, shape, input1.data(), shape, input2.data(), shape, output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
}

// Broadcast
{
// Shape: {1, 2, 2, 1}
std::vector<float> input1 = {10, 9, -11.1, 3};
// Shape: {1}
std::vector<float> input2 = {-3};
std::vector<float> expected_output = {-30, -27, 33.3, -9};
std::vector<float> output(4);

nnfw::cker::BinaryArithmeticOpParam param;
param.broadcast_category = nnfw::cker::BroadcastableOpCategory::kGenericBroadcast;
param.float_activation_min = std::numeric_limits<float>::lowest();
param.float_activation_max = std::numeric_limits<float>::max();

nnfw::cker::BroadcastBinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL>(
param, nnfw::cker::Shape{1, 2, 2, 1}, input1.data(), nnfw::cker::Shape{1}, input2.data(),
nnfw::cker::Shape{1, 2, 2, 1}, output.data());

for (size_t i = 0; i < expected_output.size(); ++i)
EXPECT_NEAR(output[i], expected_output[i], 1e-5f);
}

// Bool8
{
// Shape: {1, 2, 2, 1}
bool input1[4] = {true, true, false, false};
// Shape: {1, 2, 2, 1}
bool input2[4] = {true, false, true, false};
bool expected_output[4] = {true, false, false, false};
bool output[4];

nnfw::cker::BinaryArithmeticOpParam param;
nnfw::cker::Shape shape{1, 2, 2, 1};

nnfw::cker::BinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL, bool>(
param, shape, input1, shape, input2, shape, output);

for (size_t i = 0; i < 4; ++i)
EXPECT_EQ(output[i], expected_output[i]);
}

// Broadcast Bool8
{
// Shape: {1, 2, 2, 1}
bool input1[4] = {true, true, false, false};
// Shape: {1, 2, 2, 1}
bool input2[1] = {true};
bool expected_output[4] = {true, true, false, false};
bool output[4];

nnfw::cker::BinaryArithmeticOpParam param;

nnfw::cker::BroadcastBinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL, bool>(
param, nnfw::cker::Shape{1, 2, 2, 1}, input1, nnfw::cker::Shape{1}, input2,
nnfw::cker::Shape{1, 2, 2, 1}, output);

for (size_t i = 0; i < 4; ++i)
EXPECT_EQ(output[i], expected_output[i]);
}

// TODO Add other types
}

TEST(CKer_Operation, neg_MulUnsupportedBroadcastRank)
{
// Unsupported rank
{
// Shape: {1, 2, 2, 1, 1}
std::vector<float> input1 = {10, -9, -11, 7};
// Shape: {1}
std::vector<float> input2 = {-3};
std::vector<float> output(4);

nnfw::cker::BinaryArithmeticOpParam param;

EXPECT_ANY_THROW(
nnfw::cker::BroadcastBinaryArithmeticOp<nnfw::cker::BinaryArithmeticOpType::MUL>(
param, nnfw::cker::Shape{1, 2, 2, 1, 1}, input1.data(), nnfw::cker::Shape{1}, input2.data(),
nnfw::cker::Shape{1, 2, 2, 1, 1}, output.data()));
}
}
