Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[GPU/OpenCL] Initial version of SwiGLU Layer with OpenCL ops #2624

Merged
merged 1 commit into from
Jun 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions api/ccapi/include/layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* @see https://github.com/nnstreamer/nntrainer
* @author Parichay Kapoor <[email protected]>
* @author Debadri Samaddar <[email protected]>
* @author Niket Agarwal <[email protected]>
* @bug No known bugs except for NYI items
* @brief This is layers interface for c++ API
*
Expand Down Expand Up @@ -34,9 +35,10 @@ namespace train {
* @brief Enumeration of layer type
*/
enum LayerType {
LAYER_IN = ML_TRAIN_LAYER_TYPE_INPUT, /**< Input Layer type */
LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */
LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */
LAYER_IN = ML_TRAIN_LAYER_TYPE_INPUT, /**< Input Layer type */
LAYER_FC = ML_TRAIN_LAYER_TYPE_FC, /**< Fully Connected Layer type */
LAYER_SWIGLU = ML_TRAIN_LAYER_TYPE_SWIGLU, /**< Swiglu Layer type */
LAYER_BN = ML_TRAIN_LAYER_TYPE_BN, /**< Batch Normalization Layer type */
LAYER_CONV2D = ML_TRAIN_LAYER_TYPE_CONV2D, /**< Convolution 2D Layer type */
LAYER_POOLING2D = ML_TRAIN_LAYER_TYPE_POOLING2D, /**< Pooling 2D Layer type */
LAYER_FLATTEN = ML_TRAIN_LAYER_TYPE_FLATTEN, /**< Flatten Layer type */
Expand Down Expand Up @@ -295,6 +297,15 @@ inline std::unique_ptr<Layer> FullyConnected(
return createLayer(LayerType::LAYER_FC, properties, compute_engine);
}

/**
* @brief Helper function to create Swiglu layer
*/
inline std::unique_ptr<Layer>
Swiglu(const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_SWIGLU, properties, compute_engine);
}

/**
* @brief Helper function to create batch normalization layer
*/
Expand Down
1 change: 1 addition & 0 deletions api/nntrainer-api-common.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ typedef enum {
ML_TRAIN_LAYER_TYPE_POSITIONAL_ENCODING =
28, /**< Positional Encoding Layer type (Since 7.0) */
ML_TRAIN_LAYER_TYPE_IDENTITY = 29, /**< Identity Layer type (Since 8.0) */
ML_TRAIN_LAYER_TYPE_SWIGLU = 30, /**< Swiglu Layer type */
ML_TRAIN_LAYER_TYPE_PREPROCESS_FLIP =
300, /**< Preprocess flip Layer (Since 6.5) */
ML_TRAIN_LAYER_TYPE_PREPROCESS_TRANSLATE =
Expand Down
5 changes: 5 additions & 0 deletions nntrainer/cl_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
* @date 23 Feb 2024
* @see https://github.com/nnstreamer/nntrainer
* @author Debadri Samaddar <[email protected]>
* @author Niket Agarwal <[email protected]>
* @bug No known bugs except for NYI items
* @brief This file contains app context related functions and classes that
* manages the global configuration of the current OpenCL environment. It also
Expand All @@ -15,6 +16,7 @@
#include <addition_layer_cl.h>
#include <cl_context.h>
#include <fc_layer_cl.h>
#include <swiglu_cl.h>

namespace nntrainer {

Expand All @@ -31,6 +33,9 @@ static void add_default_object(ClContext &cc) {
cc.registerFactory(nntrainer::createLayer<AdditionLayerCL>,
AdditionLayerCL::type,
ml::train::LayerType::LAYER_ADDITION);

cc.registerFactory(nntrainer::createLayer<SwiGLULayerCl>, SwiGLULayerCl::type,
ml::train::LayerType::LAYER_SWIGLU);
}

static void registerer(ClContext &cc) noexcept {
Expand Down
1 change: 1 addition & 0 deletions nntrainer/layers/cl_layers/meson.build
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
cl_layer_sources = [
'fc_layer_cl.cpp',
'addition_layer_cl.cpp',
'swiglu_cl.cpp',
]

foreach s : cl_layer_sources
Expand Down
272 changes: 272 additions & 0 deletions nntrainer/layers/cl_layers/swiglu_cl.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,272 @@
// SPDX-License-Identifier: Apache-2.0
/**
*
* @file swiglu_cl.cpp
* @date 6th June 2024
* @brief Implementation of SwiGLU activation function
* @see https://github.com/nnstreamer/nntrainer
* @author Niket Agarwal <[email protected]>
* @bug No known bugs except for NYI items
*
*/

#include "swiglu_cl.h"
#include <iostream>

std::string swiglu_cl_kernel_fp16_ =
R"(
#pragma OPENCL EXTENSION cl_khr_fp16 : enable
__kernel void swiglu_cl_fp16(__global const half *in1, __global const half *in2, __global half *out) {
int i = get_global_id(0);
half swish = in1[i] * exp(in1[i]) / (1 + exp(in1[i]));
out[i] = swish * in2[i];
})";

std::string swiglu_cl_kernel_ =
R"(__kernel void swiglu_cl(__global const float *in1, __global const float *in2, __global float *out) {
int i = get_global_id(0);
float swish = in1[i] * exp(in1[i]) / (1 + exp(in1[i]));
out[i] = swish * in2[i];
})";

namespace nntrainer {

static constexpr size_t OUT_IDX = 0;
static constexpr size_t INPUT_IDX_1 = 0;
static constexpr size_t INPUT_IDX_2 = 1;

void SwiGLULayerCl::finalize(nntrainer::InitLayerContext &context) {
context.setOutputDimensions({context.getInputDimensions()[0]});
}

void SwiGLULayerCl::forwarding(RunLayerContext &context, bool training) {
Tensor &in1 = context.getInput(INPUT_IDX_1);
Tensor &in2 = context.getInput(INPUT_IDX_2);
Tensor &out = context.getOutput(OUT_IDX);
swigluProcess(in1, in2, out, context);
}

void SwiGLULayerCl::incremental_forwarding(RunLayerContext &context,
unsigned int from, unsigned int to,
bool training) {
Tensor &in1 = context.getInput(INPUT_IDX_1);
Tensor &in2 = context.getInput(INPUT_IDX_2);
Tensor &out = context.getOutput(OUT_IDX);

if (from) {
NNTR_THROW_IF(to - from != 1, std::invalid_argument)
<< "incremental step size is not 1";
from = 0;
to = 1;
}

swigluProcess(in1, in2, out, context);
}

opencl::Kernel SwiGLULayerCl::kernel_swiglu;
opencl::Kernel SwiGLULayerCl::kernel_swiglu_fp16;

void SwiGLULayerCl::swigluProcess(Tensor const &in1, Tensor const &in2,
Tensor &result, RunLayerContext &context) {

unsigned int dim1, dim2;
dim1 = in1.batch() * in1.channel() * in1.height();
dim2 = in1.width();

if (in1.getDataType() == ml::train::TensorDim::DataType::FP32) {
const float *data1 = in1.getData();
const float *data2 = in2.getData();
float *rdata = result.getData();
swiglu_cl(data1, data2, rdata, dim1, dim2, context);
} else if (in1.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
const _FP16 *data1 = in1.getData<_FP16>();
const _FP16 *data2 = in2.getData<_FP16>();
_FP16 *rdata = result.getData<_FP16>();
swiglu_cl_fp16(data1, data2, rdata, dim1, dim2, context);
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
}
}

void SwiGLULayerCl::swiglu_cl(const float *matAdata, const float *vecXdata,
float *vecYdata, unsigned int dim1,
unsigned int dim2, RunLayerContext &context) {

bool result = false;

do {
result =
context.clCreateKernel(swiglu_cl_kernel_, context.LayerKernel::SWIGLU,
SwiGLULayerCl::kernel_swiglu);
if (!result) {
break;
}

int dim = int(dim1 * dim2);
opencl::Buffer inputA(context.context_inst_, sizeof(float) * dim1 * dim2, true,
nullptr);

opencl::Buffer inputX(context.context_inst_, sizeof(float) * dim1 * dim2, true,
nullptr);

opencl::Buffer inOutY(context.context_inst_, sizeof(float) * dim1 * dim2, true,
nullptr);

result = inputA.WriteData(context.command_queue_inst_, matAdata);
if (!result) {
break;
}

result = inputX.WriteData(context.command_queue_inst_, vecXdata);
if (!result) {
break;
}

result = inOutY.WriteData(context.command_queue_inst_, vecYdata);
if (!result) {
break;
}

result = SwiGLULayerCl::kernel_swiglu.SetKernelArguments(0, &inputA,
sizeof(cl_mem));
if (!result) {
break;
}

result = SwiGLULayerCl::kernel_swiglu.SetKernelArguments(1, &inputX,
sizeof(cl_mem));
if (!result) {
break;
}

result = SwiGLULayerCl::kernel_swiglu.SetKernelArguments(2, &inOutY,
sizeof(cl_mem));
if (!result) {
break;
}

const int work_groups_count[3] = {dim, 1, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value

result = context.command_queue_inst_.DispatchCommand(
SwiGLULayerCl::kernel_swiglu, work_groups_count, work_group_size);
if (!result) {
break;
}

result = inOutY.ReadData(context.command_queue_inst_, vecYdata);
if (!result) {
break;
}

} while (false);
}

void SwiGLULayerCl::swiglu_cl_fp16(const __fp16 *matAdata,
const __fp16 *vecXdata, __fp16 *vecYdata,
unsigned int dim1, unsigned int dim2,
RunLayerContext &context) {

bool result = false;

do {
result = context.clCreateKernel(swiglu_cl_kernel_fp16_,
context.LayerKernel::SWIGLU_FP16,
SwiGLULayerCl::kernel_swiglu_fp16);
if (!result) {
break;
}

int dim = int(dim1 * dim2);
opencl::Buffer inputA(context.context_inst_, sizeof(__fp16) * dim1 * dim2, true,
nullptr);

opencl::Buffer inputX(context.context_inst_, sizeof(__fp16) * dim1 * dim2, true,
nullptr);

opencl::Buffer inOutY(context.context_inst_, sizeof(__fp16) * dim1 * dim2, true,
nullptr);

result = inputA.WriteData(context.command_queue_inst_, matAdata);
if (!result) {
break;
}

result = inputX.WriteData(context.command_queue_inst_, vecXdata);
if (!result) {
break;
}

result = inOutY.WriteData(context.command_queue_inst_, vecYdata);
if (!result) {
break;
}

result = SwiGLULayerCl::kernel_swiglu_fp16.SetKernelArguments(
0, &inputA, sizeof(cl_mem));
if (!result) {
break;
}

result = SwiGLULayerCl::kernel_swiglu_fp16.SetKernelArguments(
1, &inputX, sizeof(cl_mem));
if (!result) {
break;
}

result = SwiGLULayerCl::kernel_swiglu_fp16.SetKernelArguments(
2, &inOutY, sizeof(cl_mem));
if (!result) {
break;
}

const int work_groups_count[3] = {dim, 1, 1};
const int work_group_size[3] = {32, 32, 1}; // test-value

result = context.command_queue_inst_.DispatchCommand(
SwiGLULayerCl::kernel_swiglu_fp16, work_groups_count, work_group_size);
if (!result) {
break;
}

result = inOutY.ReadData(context.command_queue_inst_, vecYdata);
if (!result) {
break;
}

} while (false);
}

void SwiGLULayerCl::calcDerivative(nntrainer::RunLayerContext &context) {
std::throw_with_nested(std::runtime_error("Training is not supported yet."));
}

void SwiGLULayerCl::setProperty(const std::vector<std::string> &values) {
auto remain_props = loadProperties(values, swiglu_props);
if (!remain_props.empty()) {
std::string msg = "[SwigluLayerCl] Unknown Layer Properties count " +
std::to_string(values.size());
throw exception::not_supported(msg);
}
}

#ifdef PLUGGABLE

Layer *create_swiglu_layer_cl() {
auto layer = new SwiGLULayerCl();
return layer;
}

void destroy_swiglu_layer_cl(Layer *layer) {
delete layer;
}

extern "C" {
LayerPluggable ml_train_layer_pluggable{create_swiglu_layer_cl,
destroy_swiglu_layer_cl};
}

#endif
} // namespace nntrainer
Loading