Skip to content

Commit

Permalink
[GPU/OpenCL] Initial version of FC Layer with OpenCL ops
Browse files Browse the repository at this point in the history
Added naive version of OpenCl implementation for FC Layer.
Incorporated separate kernels for ops used.
Added unit test for fc_layer_cl.

Signed-off-by: Debadri Samaddar <[email protected]>
  • Loading branch information
s-debadri committed May 7, 2024
1 parent f9a4cd4 commit f49a75e
Show file tree
Hide file tree
Showing 10 changed files with 886 additions and 10 deletions.
12 changes: 11 additions & 1 deletion api/ccapi/include/layer.h
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,21 @@ Input(const std::vector<std::string> &properties = {}) {
/**
* @brief Helper function to create fully connected layer
*/
inline std::unique_ptr<Layer> FullyConnected(
inline std::unique_ptr<Layer>
FullyConnected(const std::vector<std::string> &properties = {}) {
return createLayer(LayerType::LAYER_FC, properties);
}

#ifdef ENABLE_OPENCL
/**
* @brief Helper function to create fully connected layer for GPU
*/
inline std::unique_ptr<Layer> FullyConnectedCl(
const std::vector<std::string> &properties = {},
const LayerComputeEngine &compute_engine = LayerComputeEngine::CPU) {
return createLayer(LayerType::LAYER_FC, properties, compute_engine);
}
#endif

/**
* @brief Helper function to create batch normalization layer
Expand Down
7 changes: 4 additions & 3 deletions nntrainer/cl_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
*/

#include <cl_context.h>
#include <fc_layer.h>
#include <fc_layer_cl.h>

namespace nntrainer {

Expand All @@ -23,8 +23,9 @@ std::once_flag global_cl_context_init_flag;

static void add_default_object(ClContext &cc) {

cc.registerFactory(nntrainer::createLayer<FullyConnectedLayer>,
FullyConnectedLayer::type, ml::train::LayerType::LAYER_FC);
cc.registerFactory(nntrainer::createLayer<FullyConnectedLayerCl>,
FullyConnectedLayerCl::type,
ml::train::LayerType::LAYER_FC);
}

static void registerer(ClContext &cc) noexcept {
Expand Down
Loading

0 comments on commit f49a75e

Please sign in to comment.