Skip to content

Commit

Permalink
[GPU/OpenCL] Updated the SwiGLU, Reshape and Concat Layers
Browse files Browse the repository at this point in the history
Updated the swiglu, reshape, and concat layers with the new shared_ptr flow.
Replaced clCreateKernel with registerClKernel for all these layers.

Self evaluation:

	Build test: [X]Passed [ ]Failed [ ]Skipped
        Run test: [X]Passed [ ]Failed [ ]Skipped

Signed-off-by: Niket Agarwal <[email protected]>
  • Loading branch information
niket-agarwal committed Oct 14, 2024
1 parent cb5afd8 commit a6fd820
Show file tree
Hide file tree
Showing 8 changed files with 369 additions and 411 deletions.
15 changes: 6 additions & 9 deletions nntrainer/cl_context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,21 +39,18 @@ static void add_default_object(ClContext &cc) {
// AdditionLayerCL::type,
// ml::train::LayerType::LAYER_ADDITION);

// cc.registerFactory(nntrainer::createLayer<SwiGLULayerCl>,
// SwiGLULayerCl::type,
// ml::train::LayerType::LAYER_SWIGLU);
cc.registerFactory(nntrainer::createLayer<SwiGLULayerCl>, SwiGLULayerCl::type,
ml::train::LayerType::LAYER_SWIGLU);

// cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
// ReshapeLayerCl::type,
// ml::train::LayerType::LAYER_RESHAPE);
cc.registerFactory(nntrainer::createLayer<ReshapeLayerCl>,
ReshapeLayerCl::type, ml::train::LayerType::LAYER_RESHAPE);

// cc.registerFactory(nntrainer::createLayer<RMSNormLayerCl>,
// RMSNormLayerCl::type,
// ml::train::LayerType::LAYER_RMSNORM);

// cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>,
// ConcatLayerCl::type,
// ml::train::LayerType::LAYER_CONCAT);
cc.registerFactory(nntrainer::createLayer<ConcatLayerCl>, ConcatLayerCl::type,
ml::train::LayerType::LAYER_CONCAT);
}

static void registerer(ClContext &cc) noexcept {
Expand Down
Loading

0 comments on commit a6fd820

Please sign in to comment.