[onert-micro] Reduce ArgMax and ArgMin code duplication
This PR reduces code duplication between the ArgMin and ArgMax kernels by extracting their shared execution logic into a common execute_arg_common helper.

ONE-DCO-1.0-Signed-off-by: Artem Balyshev <[email protected]>
Artem Balyshev committed Jun 26, 2024
1 parent 1d53f9f commit e3a7340
Showing 10 changed files with 272 additions and 240 deletions.
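At a glance, the refactoring moves all tensor lookup, data fetching, and type dispatch into the shared execute_arg_common helper, and each kernel now only supplies its PAL routine wrapped in a lambda. As a minimal sketch of the resulting pattern (the kernel name ArgSomething is hypothetical and pal::ArgMax merely stands in for the kernel-specific reduction), a further Arg-style kernel would look like this:

#include "execute/kernels/ArgCommon.h"
#include "PALArgMax.h" // stand-in for the kernel-specific PAL header

using namespace onert_micro;
using namespace onert_micro::execute;

// Hypothetical Arg-style kernel reusing the shared helper added in this commit.
OMStatus execute_kernel_CircleArgSomething(const OMExecuteArgs &execute_args)
{
  // Wrap the kernel-specific PAL routine in a lambda matching the float callback
  // signature declared in ArgCommon.h; tensor handling and type dispatch happen
  // inside execute_arg_common.
  auto arg_float_lambda = [](const core::OMRuntimeShape &input1_shape, const float *input1_data,
                             const int *input2_data, const core::OMRuntimeShape &output_shape,
                             int *output_data) {
    return pal::ArgMax(input1_shape, input1_data, input2_data, output_shape, output_data);
  };

  return execute_arg_common(execute_args, arg_float_lambda);
}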
44 changes: 44 additions & 0 deletions onert-micro/onert-micro/include/execute/kernels/ArgCommon.h
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_EXECUTE_KERNELS_ARG_COMMON_H
#define ONERT_MICRO_EXECUTE_KERNELS_ARG_COMMON_H

#include "OMStatus.h"

#include "core/OMUtils.h"
#include "core/OMKernelData.h"

#include "execute/OMKernelExecutionBuilder.h"
#include "execute/OMUtils.h"
#include "execute/OMRuntimeKernel.h"
#include <functional>

namespace onert_micro
{
namespace execute
{

OMStatus execute_arg_common(
const OMExecuteArgs &execute_args,
const std::function<OMStatus(const core::OMRuntimeShape &input1_shape, const float *input1_data,
const int *input2_data, const core::OMRuntimeShape &output_shape,
int *output_data)> &f_float);

} // namespace execute
} // namespace onert_micro

#endif // ONERT_MICRO_EXECUTE_KERNELS_ARG_COMMON_H
38 changes: 38 additions & 0 deletions onert-micro/onert-micro/include/import/helpers/OMArgCommon.h
@@ -0,0 +1,38 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#ifndef ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_ARG_KERNEL_COMMON_H
#define ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_ARG_KERNEL_COMMON_H

#include "import/OMKernelConfigureBuilder.h"
#include "core/OMUtils.h"
#include "OMStatus.h"
#include "execute/OMRuntimeKernel.h"

namespace onert_micro
{
namespace import
{
namespace helpers
{

OMStatus configure_arg_kernel_common(const OMConfigureArgs &config_args);

} // namespace helpers
} // namespace import
} // namespace onert_micro

#endif // ONERT_MICRO_IMPORT_HELPERS_CONFIGURE_ARG_KERNEL_COMMON_H
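The implementation of this helper, helpers/OMArgCommon.cpp, is added to the import CMakeLists.txt further down but is among the diffs not expanded in this view. A hedged sketch of how a kernel's configure entry point could forward to it (the function name configure_kernel_CircleArgMax and the one-line body are assumptions, not taken from this diff):

#include "import/helpers/OMArgCommon.h"

using namespace onert_micro;
using namespace onert_micro::import;

// Assumed shape of a configure entry point that delegates to the shared helper;
// the function name below is illustrative, not taken from this diff.
OMStatus configure_kernel_CircleArgMax(const OMConfigureArgs &config_args)
{
  // Input/output validation common to Arg-style kernels is expected to live in
  // configure_arg_kernel_common.
  return helpers::configure_arg_kernel_common(config_args);
}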
1 change: 1 addition & 0 deletions onert-micro/onert-micro/src/execute/CMakeLists.txt
@@ -16,6 +16,7 @@ set(SOURCES
OMUtils.cpp
kernels/ConvolutionCommon.cpp
kernels/PoolingCommon.cpp
kernels/ArgCommon.cpp
kernels/ReshapeCommon.cpp
)

94 changes: 94 additions & 0 deletions onert-micro/onert-micro/src/execute/kernels/ArgCommon.cpp
@@ -0,0 +1,94 @@
/*
* Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "execute/kernels/ArgCommon.h"
#include "PALArgMax.h"

using namespace onert_micro;
using namespace onert_micro::execute;

namespace
{

constexpr uint32_t input1TensorIdx = 0;
constexpr uint32_t input2TensorIdx = 1;
constexpr uint32_t outputTensorIdx = 0;

} // namespace

OMStatus onert_micro::execute::execute_arg_common(
const OMExecuteArgs &execute_args,
const std::function<OMStatus(const core::OMRuntimeShape &input1_shape, const float *input1_data,
const int *input2_data, const core::OMRuntimeShape &output_shape,
int *output_data)> &f_float)
{
core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
uint16_t op_index = execute_args.kernel_index;
const circle::Tensor *output;
const circle::Tensor *input1;
const circle::Tensor *input2;

uint8_t *output_data;
uint8_t *input_data;
uint8_t *axis_data;

// Read kernel
execute::OMRuntimeKernel runtime_kernel;
runtime_kernel.readKernel(op_index, runtime_context);

output = runtime_kernel.outputs[outputTensorIdx];
assert(output != nullptr);

input1 = runtime_kernel.inputs[input1TensorIdx];
assert(input1 != nullptr);

input2 = runtime_kernel.inputs[input2TensorIdx];
assert(input2 != nullptr);

runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);

output_data = runtime_kernel.outputs_data[outputTensorIdx];
assert(output_data != nullptr);

input_data = runtime_kernel.inputs_data[input1TensorIdx];
assert(input_data != nullptr);

axis_data = runtime_kernel.inputs_data[input2TensorIdx];
assert(axis_data != nullptr);

OMStatus status;
const core::OMRuntimeShape input1_shape(input1);
const core::OMRuntimeShape output_shape(output);
switch (input1->type())
{
#ifndef DIS_FLOAT
case circle::TensorType_FLOAT32:
{
status = f_float(input1_shape, reinterpret_cast<const float *>(input_data),
reinterpret_cast<const int *>(axis_data), output_shape,
reinterpret_cast<int *>(output_data));
}
break;
#endif // DIS_FLOAT
default:
{
status = UnsupportedType;
assert(false && "Unsupported type.");
}
}
return status;
}
78 changes: 9 additions & 69 deletions onert-micro/onert-micro/src/execute/kernels/ArgMax.cpp
@@ -14,80 +14,20 @@
* limitations under the License.
*/

#include "execute/OMKernelExecutionBuilder.h"
#include "OMStatus.h"
#include "execute/OMRuntimeKernel.h"

#include "core/OMRuntimeShape.h"
#include "execute/kernels/ArgCommon.h"
#include "PALArgMax.h"

using namespace onert_micro;
using namespace onert_micro::execute;

namespace
{
constexpr uint32_t input1TensorIdx = 0;
constexpr uint32_t input2TensorIdx = 1;
constexpr uint32_t outputTensorIdx = 0;
} // namespace

OMStatus onert_micro::execute::execute_kernel_CircleArgMax(const OMExecuteArgs &execute_args)
{
core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
uint16_t op_index = execute_args.kernel_index;
const circle::Tensor *output;
const circle::Tensor *input1;
const circle::Tensor *input2;

uint8_t *output_data;
uint8_t *input_data;
uint8_t *axis_data;

// Read kernel
execute::OMRuntimeKernel runtime_kernel;
runtime_kernel.readKernel(op_index, runtime_context);

output = runtime_kernel.outputs[outputTensorIdx];
assert(output != nullptr);

input1 = runtime_kernel.inputs[input1TensorIdx];
assert(input1 != nullptr);

input2 = runtime_kernel.inputs[input2TensorIdx];
assert(input2 != nullptr);

runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);

output_data = runtime_kernel.outputs_data[outputTensorIdx];
assert(output_data != nullptr);

input_data = runtime_kernel.inputs_data[input1TensorIdx];
assert(input_data != nullptr);

axis_data = runtime_kernel.inputs_data[input2TensorIdx];
assert(axis_data != nullptr);

OMStatus status;
const core::OMRuntimeShape input1_shape(input1);
const core::OMRuntimeShape output_shape(output);
switch (input1->type())
{
#ifndef DIS_FLOAT
case circle::TensorType_FLOAT32:
{
status =
onert_micro::execute::pal::ArgMax(input1_shape, reinterpret_cast<const float *>(input_data),
reinterpret_cast<const int *>(axis_data), output_shape,
reinterpret_cast<int *>(output_data));
}
break;
#endif // DIS_FLOAT
default:
{
status = UnsupportedType;
assert(false && "Unsupported type.");
}
}
return status;
auto arg_max_float_lambda = [](const core::OMRuntimeShape &input1_shape, const float *input1_data,
const int *input2_data, const core::OMRuntimeShape &output_shape,
int *output_data) {
return onert_micro::execute::pal::ArgMax(input1_shape, input1_data, input2_data, output_shape,
output_data);
};

return execute_arg_common(execute_args, arg_max_float_lambda);
}
78 changes: 9 additions & 69 deletions onert-micro/onert-micro/src/execute/kernels/ArgMin.cpp
@@ -14,80 +14,20 @@
* limitations under the License.
*/

#include "execute/OMKernelExecutionBuilder.h"
#include "OMStatus.h"
#include "execute/OMRuntimeKernel.h"

#include "core/OMRuntimeShape.h"
#include "execute/kernels/ArgCommon.h"
#include "PALArgMin.h"

using namespace onert_micro;
using namespace onert_micro::execute;

namespace
{
constexpr uint32_t input1TensorIdx = 0;
constexpr uint32_t input2TensorIdx = 1;
constexpr uint32_t outputTensorIdx = 0;
} // namespace

OMStatus onert_micro::execute::execute_kernel_CircleArgMin(const OMExecuteArgs &execute_args)
{
core::OMRuntimeContext &runtime_context = execute_args.runtime_context;
core::OMRuntimeStorage &runtime_storage = execute_args.runtime_storage;
uint16_t op_index = execute_args.kernel_index;
const circle::Tensor *output;
const circle::Tensor *input1;
const circle::Tensor *input2;

uint8_t *output_data;
uint8_t *input_data;
uint8_t *axis_data;

// Read kernel
execute::OMRuntimeKernel runtime_kernel;
runtime_kernel.readKernel(op_index, runtime_context);

output = runtime_kernel.outputs[outputTensorIdx];
assert(output != nullptr);

input1 = runtime_kernel.inputs[input1TensorIdx];
assert(input1 != nullptr);

input2 = runtime_kernel.inputs[input2TensorIdx];
assert(input2 != nullptr);

runtime_kernel.getDataFromStorage(op_index, runtime_storage, runtime_context);

output_data = runtime_kernel.outputs_data[outputTensorIdx];
assert(output_data != nullptr);

input_data = runtime_kernel.inputs_data[input1TensorIdx];
assert(input_data != nullptr);

axis_data = runtime_kernel.inputs_data[input2TensorIdx];
assert(axis_data != nullptr);

OMStatus status;
const core::OMRuntimeShape input1_shape(input1);
const core::OMRuntimeShape output_shape(output);
switch (input1->type())
{
#ifndef DIS_FLOAT
case circle::TensorType_FLOAT32:
{
status =
onert_micro::execute::pal::ArgMin(input1_shape, reinterpret_cast<const float *>(input_data),
reinterpret_cast<const int *>(axis_data), output_shape,
reinterpret_cast<int *>(output_data));
}
break;
#endif // DIS_FLOAT
default:
{
status = UnsupportedType;
assert(false && "Unsupported type.");
}
}
return status;
auto arg_min_float_lambda = [](const core::OMRuntimeShape &input1_shape, const float *input1_data,
const int *input2_data, const core::OMRuntimeShape &output_shape,
int *output_data) {
return onert_micro::execute::pal::ArgMin(input1_shape, input1_data, input2_data, output_shape,
output_data);
};

return execute_arg_common(execute_args, arg_min_float_lambda);
}
1 change: 1 addition & 0 deletions onert-micro/onert-micro/src/import/CMakeLists.txt
@@ -8,6 +8,7 @@ set(SOURCES
helpers/OMPadCommon.cpp
helpers/OMConfigureTISOKernel.cpp
helpers/OMPoolingCommon.cpp
helpers/OMArgCommon.cpp
)

# Add configure kernels
(The remaining three changed files are not expanded in this commit view.)