Skip to content

Commit

Permalink
[GPU/OpenCL] Initial version of Addition Layer with OpenCL ops
Browse files Browse the repository at this point in the history
Added naive version of OpenCL implementation for Addition Layer.
Incorporated kernel for ops used.
Added unit test for addition_layer_cl.

Signed-off-by: yash.singh <[email protected]>
  • Loading branch information
yashSingh0723 committed May 27, 2024
1 parent 3fe9a1e commit a8f089f
Show file tree
Hide file tree
Showing 10 changed files with 428 additions and 2 deletions.
11 changes: 11 additions & 0 deletions api/ccapi/include/layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -359,6 +359,17 @@ Addition(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_ADDITION, properties);
}

#ifdef ENABLE_OPENCL
/**
 * @brief Helper function to create Addition layer with a selectable compute
 *        engine (intended for the GPU/OpenCL path).
 * @param properties layer properties forwarded to createLayer()
 * @param compute_engine engine to run the layer on.
 *        NOTE(review): defaults to LayerComputeEngine::CPU even though this
 *        helper targets the GPU — callers must pass LayerComputeEngine::GPU
 *        explicitly; confirm the default is intended.
 * @return unique_ptr to the created Addition layer
 */
inline std::unique_ptr<Layer>
AdditionCL(const std::vector<std::string> &properties = {},
           const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
  return createLayer(LayerType::LAYER_ADDITION, properties, compute_engine);
}
#endif

/**
* @brief Helper function to create concat layer
*/
Expand Down
5 changes: 5 additions & 0 deletions nntrainer/cl_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
* creates the OpenCL command queue and context.
*/

#include <addition_layer_cl.h>
#include <cl_context.h>
#include <fc_layer_cl.h>

Expand All @@ -26,6 +27,10 @@ static void add_default_object(ClContext &cc) {
cc.registerFactory(nntrainer::createLayer<FullyConnectedLayerCl>,
FullyConnectedLayerCl::type,
ml::train::LayerType::LAYER_FC);

cc.registerFactory(nntrainer::createLayer<AdditionLayerCL>,
AdditionLayerCL::type,
ml::train::LayerType::LAYER_ADDITION);
}

static void registerer(ClContext &cc) noexcept {
Expand Down
210 changes: 210 additions & 0 deletions nntrainer/layers/cl_layers/addition_layer_cl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Yash Singh <[email protected]>
*
* @file addition_layer_cl.cpp
* @date 17 May 2024
* @see https://github.com/nnstreamer/nntrainer
* @author Yash Singh <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is the Addition Layer class for Neural Network with OpenCL
* implementation
*/

#include <addition_layer_cl.h>
#include <nntrainer_error.h>
#include <nntrainer_log.h>
#include <node_exporter.h>
#include <util_func.h>

#include <layer_context.h>

/**
 * @brief OpenCL C source for the element-wise accumulate kernel:
 *        output[i] += input[i] for i in [0, size).
 *
 * Declared static const: the source is never modified at runtime, and
 * internal linkage avoids exporting a mutable global from this TU.
 * The string contents are unchanged from the original.
 */
static const std::string addition_cl_kernel_ =
  R"(__kernel void addition_cl(__global const float* input, __global float* output, const unsigned int size) {
#pragma printf_support
size_t idx = get_global_id(0);
if (idx < size) {
output[idx] = output[idx] + input[idx];
}
})";

namespace nntrainer {

/** index used for layers that have a single input / single output tensor */
static constexpr size_t SINGLE_INOUT_IDX = 0;

/**
 * @brief Finalize the layer: the output shape mirrors the first input shape.
 */
void AdditionLayerCL::finalize(InitLayerContext &context) {
  const auto &in_dims = context.getInputDimensions();
  context.setOutputDimensions({in_dims[0]});
}

/**
 * @brief Forward pass: the first input initializes the output tensor, every
 *        following input is accumulated into it via the OpenCL add kernel.
 */
void AdditionLayerCL::forwarding(RunLayerContext &context, bool training) {
  Tensor &output = context.getOutput(SINGLE_INOUT_IDX);
  const unsigned int num_inputs = context.getNumInputs();

  /** @todo check possibility for in-place of addition layer */
  for (unsigned int i = 0; i < num_inputs; ++i) {
    const Tensor &in = context.getInput(i);
    if (i == 0) {
      // first operand seeds the accumulator
      output.copy(in);
    } else {
      // remaining operands are summed on the GPU
      AddProcess(in, output, context);
    }
  }
}

/**
 * @brief Definition of the static OpenCL kernel object shared by all
 *        AdditionLayerCL instances; created via context.clCreateKernel()
 *        inside addition_cl().
 */
opencl::Kernel AdditionLayerCL::kernel_addition;

/**
 * @brief Validate both operands and launch the OpenCL element-wise addition
 *        (result += input) for FP32 tensors.
 * @param[in] input addend tensor
 * @param[in,out] result accumulator tensor (must match input's dimensions)
 * @param[in] context RunLayerContext providing the OpenCL context/queue
 * @throws std::invalid_argument on unallocated data, dimension mismatch, or
 *         a non-FP32 data type (fp16 is not supported yet)
 */
void AdditionLayerCL::AddProcess(Tensor const &input, Tensor &result,
                                 RunLayerContext &context) {

  CREATE_IF_EMPTY_DIMS(result, result.getDim());

  NNTR_THROW_IF(result.getData() == nullptr, std::invalid_argument)
    << result.getName() << " is not allocated";
  NNTR_THROW_IF(input.getData() == nullptr, std::invalid_argument)
    << input.getName() << " is not allocated";

  if (input.getDim() != result.getDim()) {
    throw std::invalid_argument("Error: Dimensions does not match for addition");
  }

  // guard clause: only FP32 is handled by the current kernel
  if (input.getDataType() != ml::train::TensorDim::DataType::FP32)
    throw std::invalid_argument("Error: OpenCL fp16 is not supported yet.");

  const unsigned int num_elems = input.size();
  addition_cl(input.getData(), result.getData(), num_elems, context);
}

/**
 * @brief Execute the element-wise addition kernel: res[i] += input[i].
 *
 * Creates (or reuses) the "addition_cl" kernel, writes both host vectors to
 * device buffers, sets the kernel arguments, dispatches, then reads the
 * accumulated result back into @a res. Any OpenCL failure breaks out of the
 * do-while early, leaving @a res with whatever had completed so far; no
 * error is reported to the caller.
 *
 * @param[in] input pointer to the addend vector (host memory)
 * @param[in,out] res pointer to the accumulator vector (host memory)
 * @param[in] size number of float elements in each vector
 * @param[in] context RunLayerContext providing OpenCL context/queue access
 */
void AdditionLayerCL::addition_cl(const float *input, float *res,
                                  unsigned int size, RunLayerContext &context) {

  bool result = false;
  do {
    // fixed: original had a duplicated "result = result =" assignment
    result =
      context.clCreateKernel(addition_cl_kernel_, context.LayerKernel::ADD,
                             AdditionLayerCL::kernel_addition);
    if (!result) {
      break;
    }

    size_t dim1_size = sizeof(float) * size;
    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);

    opencl::Buffer inOutRes(context.context_inst_, dim1_size, true, nullptr);

    result = inputA.WriteData(context.command_queue_inst_, input);
    if (!result) {
      break;
    }

    result = inOutRes.WriteData(context.command_queue_inst_, res);
    if (!result) {
      break;
    }

    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
      0, &inputA, sizeof(cl_mem));
    if (!result) {
      break;
    }

    result = AdditionLayerCL::kernel_addition.SetKernelArguments(
      1, &inOutRes, sizeof(cl_mem));
    if (!result) {
      break;
    }

    result = AdditionLayerCL::kernel_addition.SetKernelArguments(2, &size,
                                                                 sizeof(int));
    if (!result) {
      break;
    }

    // NOTE(review): a {32, 32, 1} local size with a 1-D global size of
    // {size, 1, 1} looks inconsistent (OpenCL requires the global size to be
    // a multiple of the local size per dimension) — confirm DispatchCommand
    // rounds/adjusts this internally.
    const int work_groups_count[3] = {(int)size, 1, 1};
    const int work_group_size[3] = {32, 32, 1}; // test-value
    result = context.command_queue_inst_.DispatchCommand(
      AdditionLayerCL::kernel_addition, work_groups_count, work_group_size);
    if (!result) {
      break;
    }

    result = inOutRes.ReadData(context.command_queue_inst_, res);
    if (!result) {
      break;
    }

  } while (false);
}

/**
 * @brief Incremental (step-wise) forwarding: for each batch item, accumulate
 *        every input's current step slice into the matching output slice.
 *
 * Operates on height-wise slices of size (to - from). When from != 0 the
 * step must be exactly 1 and the window is remapped to [0, 1) — presumably
 * because only the newest step is materialized in the tensors; confirm
 * against callers.
 */
void AdditionLayerCL::incremental_forwarding(RunLayerContext &context,
                                             unsigned int from, unsigned int to,
                                             bool training) {
  Tensor &hidden_ = context.getOutput(SINGLE_INOUT_IDX);
  TensorDim hidden_dim = hidden_.getDim();
  TensorDim hidden_step_dim = hidden_dim;

  if (from) {
    NNTR_THROW_IF(to - from != 1, std::invalid_argument)
      << "incremental step size is not 1";
    // only one new step is processed; shift the window to [0, 1)
    from = 0;
    to = 1;
  }

  // a step tensor covers one batch item and (to - from) rows
  hidden_step_dim.batch(1);
  hidden_step_dim.height(to - from);

  for (unsigned int b = 0; b < hidden_.batch(); ++b) {
    // shared (non-owning) view into this batch item's slice of the output
    Tensor hidden_step = hidden_.getSharedDataTensor(
      hidden_step_dim, b * hidden_dim.getFeatureLen(), true);

    /** @todo check possibility for in-place of addition layer */
    for (unsigned int idx = 0; idx < context.getNumInputs(); ++idx) {
      const Tensor &input_ = context.getInput(idx);
      TensorDim input_dim = input_.getDim();

      TensorDim input_step_dim = input_dim;
      input_step_dim.batch(1);
      input_step_dim.height(to - from);

      // matching view into this batch item's slice of input idx
      Tensor input_step = input_.getSharedDataTensor(
        input_step_dim, b * input_dim.getFeatureLen(), true);
      if (!idx) {
        // first input seeds the accumulator slice
        hidden_step.copy(input_step);
      } else {
        // hidden_step.add_i(input_step);
        AddProcess(input_step, hidden_step, context);
      }
    }
  }
}

/**
 * @brief Backward pass: addition routes the incoming derivative unchanged to
 *        every input, so each outgoing derivative is a copy of it.
 */
void AdditionLayerCL::calcDerivative(RunLayerContext &context) {

  const unsigned int num_inputs = context.getNumInputs();
  for (unsigned int i = 0; i < num_inputs; ++i) {
    /**
     * TODO: replace this with tensor assignment during optimization.
     * Tensor assignment needs to make sure that the previous connected layers
     * are not inplace
     */
    context.getOutgoingDerivative(i).copy(
      context.getIncomingDerivative(SINGLE_INOUT_IDX));
  }
}

/**
 * @brief Set the layer properties; any property not consumed by add_props is
 *        unsupported.
 * @param[in] values property strings ("key=value")
 * @throws exception::not_supported if any property is left unconsumed
 */
void AdditionLayerCL::setProperty(const std::vector<std::string> &values) {
  auto remain_props = loadProperties(values, add_props);
  if (!remain_props.empty()) {
    // fixed: report the count of *unknown* properties, not of all values
    std::string msg = "[AdditionLayerCL] Unknown Layer Properties count " +
                      std::to_string(remain_props.size());
    throw exception::not_supported(msg);
  }
}
} /* namespace nntrainer */
136 changes: 136 additions & 0 deletions nntrainer/layers/cl_layers/addition_layer_cl.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
// SPDX-License-Identifier: Apache-2.0
/**
* Copyright (C) 2024 Yash Singh <[email protected]>
*
* @file addition_layer_cl.h
* @date 17 May 2024
* @see https://github.com/nnstreamer/nntrainer
* @author Yash Singh <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is the Addition Layer class for Neural Network with OpenCL
* implementation
*/

#ifndef __ADDITION_LAYER_CL_H__
#define __ADDITION_LAYER_CL_H__
#ifdef __cplusplus

#include <common_properties.h>
#include <layer_devel.h>
#include <opencl_buffer.h>
#include <opencl_kernel.h>

/**
 * @brief Reconstruct @p tensor as Tensor(__VA_ARGS__) if it is empty.
 *
 * Fixed: the original had a trailing semicolon after `while (0)`, which
 * defeats the do-while(0) idiom — `if (c) CREATE_IF_EMPTY_DIMS(t, ...); else`
 * would fail to compile because the macro expanded to two statements.
 */
#define CREATE_IF_EMPTY_DIMS(tensor, ...) \
  do {                                    \
    if (tensor.empty())                   \
      tensor = Tensor(__VA_ARGS__);       \
  } while (0)

namespace nntrainer {

/**
* @class AdditionLayerCL
* @brief Addition Layer
*/
/**
 * @class   AdditionLayerCL
 * @brief   Addition layer that accumulates all of its input tensors into a
 *          single output tensor, offloading the element-wise add to an
 *          OpenCL kernel (FP32 only).
 */
class AdditionLayerCL : public Layer {
public:
  /**
   * @brief Constructor of Addition Layer
   */
  AdditionLayerCL() : Layer(), add_props(props::Print()) {}

  /**
   * @brief Destructor of Addition Layer
   */
  ~AdditionLayerCL(){};

  /**
   * @brief Move constructor of AdditionLayerCL.
   * @param[in] rhs AdditionLayerCL to be moved.
   */
  AdditionLayerCL(AdditionLayerCL &&rhs) noexcept = default;

  /**
   * @brief Move assignment operator.
   * @param[in] rhs AdditionLayerCL to be moved.
   */
  AdditionLayerCL &operator=(AdditionLayerCL &&rhs) = default;

  /**
   * @copydoc Layer::finalize(InitLayerContext &context)
   */
  void finalize(InitLayerContext &context) override;

  /**
   * @copydoc Layer::forwarding(RunLayerContext &context, bool training)
   */
  void forwarding(RunLayerContext &context, bool training) override;

  /**
   * @copydoc Layer::incremental_forwarding(RunLayerContext &context, unsigned
   * int from, unsigned int to, bool training)
   */
  void incremental_forwarding(RunLayerContext &context, unsigned int from,
                              unsigned int to, bool training) override;

  /**
   * @copydoc Layer::calcDerivative(RunLayerContext &context)
   */
  void calcDerivative(RunLayerContext &context) override;

  /**
   * @brief OpenCL kernel object shared by all AdditionLayerCL instances;
   *        created lazily through RunLayerContext::clCreateKernel().
   */
  static opencl::Kernel kernel_addition;

  /**
   * @brief Validate operands and run the OpenCL add (result += input)
   * @param[in] input addend Tensor
   * @param[in,out] result accumulator Tensor (dimensions must match input)
   * @param[in] context RunLayerContext reference
   */
  void AddProcess(Tensor const &input, Tensor &result,
                  RunLayerContext &context);

  /**
   * @brief addition : element-wise accumulate res[i] += input[i] on the GPU
   * @param[in] input float * for input (host memory)
   * @param[in,out] res float * for result/output (host memory)
   * @param[in] size number of elements in input vector
   * @param[in] context RunLayerContext reference
   */
  void addition_cl(const float *input, float *res, unsigned int size,
                   RunLayerContext &context);

  /**
   * @copydoc bool supportBackwarding() const
   */
  bool supportBackwarding() const override { return true; };

  /**
   * @copydoc Layer::exportTo(Exporter &exporter, ml::train::ExportMethods
   * method)
   */
  void exportTo(Exporter &exporter,
                const ml::train::ExportMethods &method) const override {}

  /**
   * @copydoc Layer::setProperty(const std::vector<std::string> &values)
   */
  void setProperty(const std::vector<std::string> &values) override;

  /**
   * @copydoc Layer::getType()
   */
  const std::string getType() const override { return AdditionLayerCL::type; };

  std::tuple<props::Print>
    add_props; /**< addition layer properties (only props::Print) */

  inline static const std::string type = "addition"; /**< registry type key */
};

} // namespace nntrainer

#endif /* __cplusplus */
#endif /* __ADDITION_LAYER_CL_H__ */
3 changes: 2 additions & 1 deletion nntrainer/layers/cl_layers/meson.build
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
cl_layer_sources = [
'fc_layer_cl.cpp',
'blas_kernels.cpp'
'blas_kernels.cpp',
'addition_layer_cl.cpp'
]

foreach s : cl_layer_sources
Expand Down
Loading

0 comments on commit a8f089f

Please sign in to comment.