[ GPU/OpenCL ] change reshape_cl to inherit LayerImplCl
- This commit updates reshape_cl.cpp/.h to inherit LayerImplCl (a sketch of
  the assumed base-class contract follows below).
- This commit implements ReshapeLayerCl::registerClKernels(), which is called
  in cl_context.cpp.
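For reference, layer_impl_cl.h itself is not part of this diff. Below is a
minimal sketch of the base-class contract that reshape_cl.{h,cpp} appear to
rely on; the parent class and exact member placement are assumptions, not the
actual definition:

// Sketch only: assumed shape of LayerImplCl, inferred from how this diff
// uses it (the real layer_impl_cl.h is not shown here).
class LayerImplCl : public Layer { // actual parent class is an assumption
protected:
  /// Shared OpenCL context; subclasses call
  /// cl_context_ref.registerClKernel(...) inside their registerClKernels().
  inline static ClContext cl_context_ref;
};

// Each concrete CL layer additionally provides a static hook,
//   static bool registerClKernels();
// which add_default_object() in cl_context.cpp calls before exposing the
// layer's factory.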

Self evaluation:

Build test: [X]Passed [ ]Failed [ ]Skipped
Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Eunju Yang <[email protected]>
EunjuYang committed Nov 6, 2024
1 parent 0935dbc commit 4354f53
Showing 4 changed files with 46 additions and 34 deletions.
24 changes: 14 additions & 10 deletions nntrainer/cl_context.cpp
@@ -32,10 +32,11 @@ std::once_flag global_cl_context_init_flag;
 
 static void add_default_object(ClContext &cc) {
 
-  FullyConnectedLayerCl::registerClKernels();
-  cc.registerFactory(nntrainer::createLayer<FullyConnectedLayerCl>,
-                     FullyConnectedLayerCl::type,
-                     ml::train::LayerType::LAYER_FC);
+  if (FullyConnectedLayerCl::registerClKernels()) {
+    cc.registerFactory(nntrainer::createLayer<FullyConnectedLayerCl>,
+                       FullyConnectedLayerCl::type,
+                       ml::train::LayerType::LAYER_FC);
+  }
 
   // cc.registerFactory(nntrainer::createLayer<AdditionLayerCL>,
   //                    AdditionLayerCL::type,
@@ -45,16 +46,19 @@ static void add_default_object(ClContext &cc) {
   //                    SwiGLULayerCl::type,
   //                    ml::train::LayerType::LAYER_SWIGLU);
 
-  ReshapeLayerCl::registerClKernels();
-  cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
-                     ReshapeLayerCl::type, ml::train::LayerType::LAYER_RESHAPE);
+  if (ReshapeLayerCl::registerClKernels()) {
+    cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
+                       ReshapeLayerCl::type,
+                       ml::train::LayerType::LAYER_RESHAPE);
+  }
 
   // cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
   //                    RMSNormLayerCl::type, ml::train::LayerType::LAYER_RMSNORM);
 
-  ConcatLayerCl::registerClKernels();
-  cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>, ConcatLayerCl::type,
-                     ml::train::LayerType::LAYER_CONCAT);
+  if (ConcatLayerCl::registerClKernels()) {
+    cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>,
+                       ConcatLayerCl::type, ml::train::LayerType::LAYER_CONCAT);
+  }
 }
 
 static void registerer(ClContext &cc) noexcept {
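The context lines above also show where this gating runs. A sketch of the
assumed one-time initialization flow follows; only global_cl_context_init_flag,
registerer(), and add_default_object() are visible in this diff, so the exact
wiring is an assumption:

// Assumed initialization flow:
std::call_once(global_cl_context_init_flag, registerer, std::ref(cc));
// registerer(cc) -> add_default_object(cc) -> per CL layer:
//   if (SomeLayerCl::registerClKernels())  // build & cache OpenCL kernels
//     cc.registerFactory(...);             // expose the layer type
// A layer whose kernel registration does not succeed is simply left
// unregistered in the context, instead of failing later at first use.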
33 changes: 20 additions & 13 deletions nntrainer/layers/cl_layers/reshape_cl.cpp
@@ -49,6 +49,24 @@ namespace nntrainer {
 
 static constexpr size_t SINGLE_INOUT_IDX = 0;
 
+bool ReshapeLayerCl::registerClKernels() {
+
+  ClContext::SharedPtrClKernel kernel_copy_ptr = nullptr;
+
+  kernel_copy_ptr = cl_context_ref.registerClKernel(copy_cl_kernel_, "copy_cl");
+  NNTR_THROW_IF(!kernel_copy_ptr, std::runtime_error)
+    << "OpenCL Error: Fail to register copy_cl kernel";
+  layer_kernel_ptrs.emplace_back(kernel_copy_ptr);
+
+  kernel_copy_ptr =
+    cl_context_ref.registerClKernel(copy_cl_kernel_fp16_, "copy_cl_fp16");
+  NNTR_THROW_IF(!kernel_copy_ptr, std::runtime_error)
+    << "OpenCL Error: Fail to register copy_cl_fp16 kernel";
+  layer_kernel_ptrs.emplace_back(kernel_copy_ptr);
+
+  return true;
+}
+
 void ReshapeLayerCl::finalize(InitLayerContext &context) {
   NNTR_THROW_IF(context.getNumInputs() != 1, std::invalid_argument)
     << "Reshape only supports 1 input for now";
@@ -98,9 +116,6 @@ void ReshapeLayerCl::incremental_forwarding(RunLayerContext &context,
   }
 }
 
-opencl::Kernel ReshapeLayerCl::kernel_copy;
-opencl::Kernel ReshapeLayerCl::kernel_copy_fp16;
-
 void ReshapeLayerCl::ReshapeProcess(Tensor const &input, Tensor &output) {
 
   unsigned int input_batch_size, input_height, input_width, input_channels;
@@ -136,11 +151,7 @@ void ReshapeLayerCl::copy_cl_fp16(const __fp16 *input, __fp16 *res,
   bool result = false;
 
   do {
-    ClContext::SharedPtrClKernel kernel_copy_ptr =
-      cl_context_ref.registerClKernel(copy_cl_kernel_fp16_, "copy_cl_fp16");
-    if (!kernel_copy_ptr) {
-      break;
-    }
+    const auto &kernel_copy_ptr = layer_kernel_ptrs[Kernels::COPY_CL_FP16];
 
     size_t dim_size = sizeof(__fp16) * input_batch_size * input_height *
                       input_width * input_channels;
@@ -219,11 +230,7 @@ void ReshapeLayerCl::copy_cl(const float *input, float *res,
   bool result = false;
 
   do {
-    ClContext::SharedPtrClKernel kernel_copy_ptr =
-      cl_context_ref.registerClKernel(copy_cl_kernel_, "copy_cl");
-    if (!kernel_copy_ptr) {
-      break;
-    }
+    const auto &kernel_copy_ptr = layer_kernel_ptrs[Kernels::COPY_CL];
 
     size_t dim_size = sizeof(float) * input_batch_size * input_height *
                       input_width * input_channels;
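The forward paths above now index a static kernel cache instead of
re-registering kernels on every call. The lookup depends on an ordering
invariant between registerClKernels() and the Kernels enum declared in
reshape_cl.h (both names appear in this diff):

// registerClKernels() emplaces kernels in enum order, so the enum value is
// also the vector index:
enum Kernels { COPY_CL, COPY_CL_FP16 };
// layer_kernel_ptrs[COPY_CL]      -> kernel built from copy_cl_kernel_
// layer_kernel_ptrs[COPY_CL_FP16] -> kernel built from copy_cl_kernel_fp16_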
20 changes: 12 additions & 8 deletions nntrainer/layers/cl_layers/reshape_cl.h
@@ -18,6 +18,7 @@
 #include <cl_context.h>
 #include <common_properties.h>
 #include <layer_devel.h>
+#include <layer_impl_cl.h>
 #include <opencl_buffer.h>
 #include <opencl_kernel.h>
 
@@ -26,10 +27,7 @@ namespace nntrainer {
  * @class   Reshape Layer
  * @brief   Reshape Layer
  */
-class ReshapeLayerCl : public Layer {
-
-private:
-  inline static ClContext cl_context_ref;
+class ReshapeLayerCl : public LayerImplCl {
 
 public:
   /**
@@ -105,16 +103,18 @@
 
   inline static const std::string type = "reshape";
 
-  static opencl::Kernel kernel_copy;
-  static opencl::Kernel kernel_copy_fp16;
-
   /**
    * @brief Process data and dimensions for reshape operation
    * @param[in] input Tensor
    * @param[in] result Tensor
    */
   void ReshapeProcess(Tensor const &input, Tensor &result);
 
+  /**
+   * @brief Register the OpenCL kernels used by this layer
+   */
+  static bool registerClKernels();
+
   /**
    * @brief     copy computation
    * @param[in] input float * for Input Tensor
@@ -145,9 +145,13 @@ class ReshapeLayerCl : public Layer {
                unsigned int input_height, unsigned int input_width);
 #endif
 
-protected:
+private:
   std::tuple<props::TargetShape>
     reshape_props; /**< reshape properties : target_shape after reshape */
+
+  inline static std::vector<ClContext::SharedPtrClKernel> layer_kernel_ptrs;
+
+  enum Kernels { COPY_CL, COPY_CL_FP16 };
 };
 
 } // namespace nntrainer
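One design note on the header change: layer_kernel_ptrs is inline static, so
the kernel cache is class-level state that is populated once and then shared
by every instance. A short sketch of the intended lifecycle; the stack
instantiation below is hypothetical:

// One-time, class-level registration; in practice this is driven by
// add_default_object() in cl_context.cpp during ClContext initialization.
ReshapeLayerCl::registerClKernels();
ReshapeLayerCl a, b; // hypothetical instances: both reuse the same
                     // layer_kernel_ptrs cache populated above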
3 changes: 0 additions & 3 deletions test/jni/Android.mk
@@ -443,13 +443,11 @@ LOCAL_SRC_FILES := \
 	../unittest/layers/unittest_layers.cpp \
 	../unittest/layers/unittest_layers_impl.cpp \
 	../unittest/layers/unittest_layers_concat_cl.cpp \
-	../unittest/layers/unittest_layers_swiglu_cl.cpp \
 	../unittest/layers/unittest_layers_fully_connected_cl.cpp \
 	../unittest/layers/unittest_layers_input.cpp \
 	../unittest/layers/unittest_layers_loss.cpp \
 	../unittest/layers/unittest_layers_reshape_cl.cpp \
 	../unittest/layers/unittest_layers_fully_connected.cpp \
-	../unittest/layers/unittest_layers_rmsnorm_cl.cpp \
 	../unittest/layers/unittest_layers_batch_normalization.cpp \
 	../unittest/layers/unittest_layers_layer_normalization.cpp \
 	../unittest/layers/unittest_layers_convolution2d.cpp \
@@ -458,7 +456,6 @@ LOCAL_SRC_FILES := \
 	../unittest/layers/unittest_layers_flatten.cpp \
 	../unittest/layers/unittest_layers_activation.cpp \
 	../unittest/layers/unittest_layers_addition.cpp \
-	../unittest/layers/unittest_layers_addition_cl.cpp \
 	../unittest/layers/unittest_layers_multiout.cpp \
 	../unittest/layers/unittest_layers_rnn.cpp \
 	../unittest/layers/unittest_layers_rnncell.cpp \
