diff --git a/nntrainer/layers/cl_layers/addition_layer_cl.cpp b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
index 1cd9f1de41..7cb9a31068 100644
--- a/nntrainer/layers/cl_layers/addition_layer_cl.cpp
+++ b/nntrainer/layers/cl_layers/addition_layer_cl.cpp
@@ -11,8 +11,8 @@
  * implementation
  */
 
-#include <blas_kernels.h>
 #include <addition_layer_cl.h>
+#include <blas_kernels.h>
 #include <nntrainer_error.h>
 #include <nntrainer_log.h>
 #include <node_exporter.h>
@@ -64,8 +64,21 @@ void AdditionLayerCL::AddProcess(Tensor const &input, Tensor &result,
 
     addition_cl(data, rdata, size, context);
 
-  } else
-    throw std::invalid_argument("Error: OpenCL fp16 is not supported yet.");
+  } else if (input.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
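+    // fp16 path mirrors the fp32 branch above, using _FP16 host pointers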
+    unsigned int size = input.size();
+    const _FP16 *data = input.getData<_FP16>();
+    _FP16 *rdata = result.getData<_FP16>();
+
+    addition_cl(data, rdata, size, context);
+
+#else
+    throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+  } else {
+    throw std::invalid_argument("Error: unsupported datatype for addition");
+  }
 }
 
 void AdditionLayerCL::incremental_forwarding(RunLayerContext &context,
diff --git a/nntrainer/layers/cl_layers/meson.build b/nntrainer/layers/cl_layers/meson.build
index 349e1f443d..f28b56cd55 100644
--- a/nntrainer/layers/cl_layers/meson.build
+++ b/nntrainer/layers/cl_layers/meson.build
@@ -1,7 +1,6 @@
 cl_layer_sources = [
   'fc_layer_cl.cpp',
-  'blas_kernels.cpp',
-  'addition_layer_cl.cpp'
+  'addition_layer_cl.cpp',
 ]
 
 foreach s : cl_layer_sources
diff --git a/nntrainer/layers/layer_context.cpp b/nntrainer/layers/layer_context.cpp
index a6615b92aa..798f8a1a5e 100644
--- a/nntrainer/layers/layer_context.cpp
+++ b/nntrainer/layers/layer_context.cpp
@@ -692,6 +692,8 @@ std::string RunLayerContext::getKernelName(LayerKernel layerKernel) {
     return "sgemm_cl_fp16";
   case LayerKernel::ADD:
     return "addition_cl";
+  case LayerKernel::ADD_FP16:
+    return "addition_cl_fp16";
   default:
     return "";
   }
diff --git a/nntrainer/layers/layer_context.h b/nntrainer/layers/layer_context.h
index 105725a57b..842789f6eb 100644
--- a/nntrainer/layers/layer_context.h
+++ b/nntrainer/layers/layer_context.h
@@ -835,7 +835,8 @@ class RunLayerContext {
     SGEMV_FP16 = 1 << 3, /**< placeholder for kernel name */
     DOT_FP16 = 1 << 4,   /**< placeholder for kernel name */
     SGEMM_FP16 = 1 << 5, /**< placeholder for kernel name */
-    ADD = 1 << 6         /**< placeholder for kernel name */
+    ADD = 1 << 6,        /**< placeholder for kernel name */
+    ADD_FP16 = 1 << 7    /**< placeholder for kernel name */
   };
 
   /**
diff --git a/nntrainer/tensor/cl_operations/blas_kernels.cpp b/nntrainer/tensor/cl_operations/blas_kernels.cpp
index 1b426137ec..1b0ff60987 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernels.cpp
@@ -309,15 +309,14 @@ void sgemm_cl(const float *A, const float *B, float *C, unsigned int M,
   } while (false);
 }
 
-void addition_cl(const float *input, float *res,
-                                  unsigned int size, RunLayerContext &context) {
+void addition_cl(const float *input, float *res, unsigned int size,
+                 RunLayerContext &context) {
 
   bool result = false;
-  
+
   do {
-    result = result =
-      context.clCreateKernel(addition_cl_kernel_, context.LayerKernel::ADD,
-                             kernel_addition);
+    result = context.clCreateKernel(addition_cl_kernel_,
+                                    context.LayerKernel::ADD, kernel_addition);
     if (!result) {
       break;
     }
diff --git a/nntrainer/tensor/cl_operations/blas_kernels.h b/nntrainer/tensor/cl_operations/blas_kernels.h
index 816c8ac913..6f7369e666 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels.h
+++ b/nntrainer/tensor/cl_operations/blas_kernels.h
@@ -29,6 +29,7 @@ extern opencl::Kernel kernel_sgemm;
 extern opencl::Kernel kernel_dot;
 extern opencl::Kernel kernel_dot_fp16;
 extern opencl::Kernel kernel_addition;
+extern opencl::Kernel kernel_addition_fp16;
 
 /**
  * @brief     sgemv computation : Y = A*X + Y
@@ -135,5 +136,15 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
 void addition_cl(const float *input, float *res, unsigned int size,
                 RunLayerContext &context);
 
+/**
+ * @brief     fp16 addition : sum of all input vectors
+ * @param[in] input fp16 * for input
+ * @param[in] res fp16 * for result/output
+ * @param[in] size number of elements in input vector
+ * @param[in] context RunLayerContext reference
+ */
+void addition_cl(const __fp16 *input, __fp16 *res, unsigned int size,
+                RunLayerContext &context);
+
 } // namespace nntrainer
 #endif /* __BLAS_KERNELS_H__ */
diff --git a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
index 8948f0dc5c..b25e150b9a 100644
--- a/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
+++ b/nntrainer/tensor/cl_operations/blas_kernels_fp16.cpp
@@ -60,12 +60,25 @@ std::string sgemm_cl_kernel_fp16_ =
         C[m * ldc + n] = c;
     })";
 
+std::string addition_cl_kernel_fp16_ =
+  R"(
+    #pragma OPENCL EXTENSION cl_khr_fp16 : enable
+
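+    // elementwise in-place accumulation: output[idx] += input[idx]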
+    __kernel void addition_cl_fp16(__global const half* input, __global half* output, const unsigned int size) {
+    size_t idx = get_global_id(0);
+    if (idx < size) {
+        output[idx] = output[idx] + input[idx];
+    }
+  })";
+
 /**
  * @brief defining global kernel objects
  */
 opencl::Kernel kernel_sgemv_fp16;
 opencl::Kernel kernel_sgemm_fp16;
 opencl::Kernel kernel_dot_fp16;
+opencl::Kernel kernel_addition_fp16;
 
 void sgemv_cl(const __fp16 *matAdata, const __fp16 *vecXdata, __fp16 *vecYdata,
               unsigned int dim1, unsigned int dim2, unsigned int lda,
@@ -309,4 +321,70 @@ void sgemm_cl(const __fp16 *A, const __fp16 *B, __fp16 *C, unsigned int M,
 
   } while (false);
 }
+
+void addition_cl(const __fp16 *input, __fp16 *res, unsigned int size,
+                 RunLayerContext &context) {
+
+  bool result = false;
+
+  do {
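+    // create the fp16 addition kernel; any failure below breaks out early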
+    result = context.clCreateKernel(addition_cl_kernel_fp16_,
+                                    context.LayerKernel::ADD_FP16,
+                                    kernel_addition_fp16);
+    if (!result) {
+      break;
+    }
+
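+    // allocate device buffers in cl_half units (2 bytes, matching host __fp16)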
+    size_t dim1_size = sizeof(cl_half) * size;
+    opencl::Buffer inputA(context.context_inst_, dim1_size, true, nullptr);
+
+    opencl::Buffer inOutRes(context.context_inst_, dim1_size, true, nullptr);
+
+    result = inputA.WriteData(context.command_queue_inst_, input);
+    if (!result) {
+      break;
+    }
+
+    result = inOutRes.WriteData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
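+    // bind kernel arguments: input buffer (0), in/out buffer (1), element count (2)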
+    result =
+      kernel_addition_fp16.SetKernelArguments(0, &inputA, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result =
+      kernel_addition_fp16.SetKernelArguments(1, &inOutRes, sizeof(cl_mem));
+    if (!result) {
+      break;
+    }
+
+    result = kernel_addition_fp16.SetKernelArguments(2, &size, sizeof(int));
+    if (!result) {
+      break;
+    }
+
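+    // dispatch one work-item per element of the vectors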
+    const int work_groups_count[3] = {(int)size, 1, 1};
+    const int work_group_size[3] = {32, 32, 1}; // test-value
+    result = context.command_queue_inst_.DispatchCommand(
+      kernel_addition_fp16, work_groups_count, work_group_size);
+    if (!result) {
+      break;
+    }
+
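+    // copy the accumulated result back from device to host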
+    result = inOutRes.ReadData(context.command_queue_inst_, res);
+    if (!result) {
+      break;
+    }
+
+  } while (false);
+}
 } // namespace nntrainer
diff --git a/test/input_gen/gen_layer_tests.py b/test/input_gen/gen_layer_tests.py
index 5c20c7b10d..cf8e713983 100644
--- a/test/input_gen/gen_layer_tests.py
+++ b/test/input_gen/gen_layer_tests.py
@@ -889,6 +889,3 @@ def swiglu(inputs):
     
     added = K.layers.Add()
     record_single(added, [(3, 4, 3, 4), (3, 4, 3, 4)], "added_w32a32_2")
-    
-    added = K.layers.Add()
-    record_single(added, [(20, 55, 50, 55), (20, 55, 50, 55)], "added_w32a32_3")
diff --git a/test/unittest/layers/unittest_layers_addition_cl.cpp b/test/unittest/layers/unittest_layers_addition_cl.cpp
index a5d6907582..e7feaaaa50 100644
--- a/test/unittest/layers/unittest_layers_addition_cl.cpp
+++ b/test/unittest/layers/unittest_layers_addition_cl.cpp
@@ -26,7 +26,7 @@ auto semantic_addition_multi_gpu = LayerSemanticsParamType(
   nntrainer::AdditionLayerCL::type, {},
   LayerCreateSetPropertyOptions::AVAILABLE_FROM_APP_CONTEXT, false, 2);
 
-GTEST_PARAMETER_TEST(AdditionGPU, LayerSemantics,
+GTEST_PARAMETER_TEST(AdditionGPU, LayerSemanticsGpu,
                      ::testing::Values(semantic_addition_gpu,
                                        semantic_addition_multi_gpu));
 
@@ -40,11 +40,15 @@ auto addition_w32a32_2 = LayerGoldenTestParamType(
   "added_w32a32_2.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
   "fp32", "fp32");
 
-auto addition_w32a32_3 = LayerGoldenTestParamType(
-  nntrainer::createLayer<nntrainer::AdditionLayerCL>, {},
-  "20:55:50:55,20:55:50:55", "added_w32a32_3.nnlayergolden",
-  LayerGoldenTestParamOptions::DEFAULT, "nchw", "fp32", "fp32");
-
 GTEST_PARAMETER_TEST(AdditionGPU, LayerGoldenTest,
-                     ::testing::Values(addition_w32a32, addition_w32a32_2,
-                                       addition_w32a32_3));
+                     ::testing::Values(addition_w32a32, addition_w32a32_2));
+
+#ifdef ENABLE_FP16
+auto addition_w16a16 = LayerGoldenTestParamType(
+  nntrainer::createLayer<nntrainer::AdditionLayerCL>, {}, "2:3:3:3,2:3:3:3",
+  "added_w16a16.nnlayergolden", LayerGoldenTestParamOptions::DEFAULT, "nchw",
+  "fp16", "fp16");
+
+GTEST_PARAMETER_TEST(AdditionGPU16, LayerGoldenTest,
+                     ::testing::Values(addition_w16a16));
+#endif