diff --git a/nntrainer/opencl/opencl_context_manager.cpp b/nntrainer/opencl/opencl_context_manager.cpp index 4488927b53..e6cf0201d4 100644 --- a/nntrainer/opencl/opencl_context_manager.cpp +++ b/nntrainer/opencl/opencl_context_manager.cpp @@ -155,6 +155,31 @@ bool ContextManager::CreateDefaultGPUDevice() { device_id_ = devices[0]; platform_id_ = platform_id_; +#ifdef ENABLE_FP16 + // check for fp16 (half) support available on device + // getting extensions into a std::string (avoids a non-standard VLA) + size_t extension_size = 0; + status = + clGetDeviceInfo(device_id_, CL_DEVICE_EXTENSIONS, 0, NULL, &extension_size); + if (status != CL_SUCCESS) { + ml_loge("clGetDeviceInfo returned %d", status); + return false; + } + + std::string extensions(extension_size, '\0'); + status = clGetDeviceInfo(device_id_, CL_DEVICE_EXTENSIONS, extension_size, + &extensions[0], NULL); + if (status != CL_SUCCESS) { + ml_loge("clGetDeviceInfo returned %d", status); + return false; + } + + if (extensions.find("cl_khr_fp16") == std::string::npos) { + ml_loge("fp16 (half) is not supported by device"); + return false; + } +#endif + + return true; } diff --git a/nntrainer/opencl/opencl_loader.cpp b/nntrainer/opencl/opencl_loader.cpp index 93a44a45bd..34101b6a4c 100644 --- a/nntrainer/opencl/opencl_loader.cpp +++ b/nntrainer/opencl/opencl_loader.cpp @@ -67,6 +67,7 @@ bool LoadOpenCL() { void LoadOpenCLFunctions(void *libopencl) { LoadFunction(clGetPlatformIDs); LoadFunction(clGetDeviceIDs); + LoadFunction(clGetDeviceInfo); LoadFunction(clCreateContext); LoadFunction(clCreateCommandQueue); LoadFunction(clCreateBuffer); @@ -91,6 +92,7 @@ void LoadOpenCLFunctions(void *libopencl) { PFN_clGetPlatformIDs clGetPlatformIDs; PFN_clGetDeviceIDs clGetDeviceIDs; +PFN_clGetDeviceInfo clGetDeviceInfo; PFN_clCreateContext clCreateContext; PFN_clCreateCommandQueue clCreateCommandQueue; PFN_clCreateBuffer clCreateBuffer; diff --git a/nntrainer/opencl/opencl_loader.h b/nntrainer/opencl/opencl_loader.h index cfbeb629c6..99142d5b12 100644 --- 
a/nntrainer/opencl/opencl_loader.h +++ b/nntrainer/opencl/opencl_loader.h @@ -38,6 +38,11 @@ typedef cl_int(CL_API_CALL *PFN_clGetDeviceIDs)( cl_uint /**< num_entries */, cl_device_id * /**< devices */, cl_uint * /**< num_devices */); +typedef cl_int(CL_API_CALL *PFN_clGetDeviceInfo)( + cl_device_id /**< device */, cl_device_info /**< param_name */, + size_t /**< param_value_size */, void * /**< param_value */, + size_t * /**< param_value_size_ret */); + typedef cl_context(CL_API_CALL *PFN_clCreateContext)( const cl_context_properties * /**< properties */, cl_uint /**< num_devices */, const cl_device_id * /**< devices */, @@ -133,6 +138,7 @@ typedef cl_int(CL_API_CALL *PFN_clReleaseMemObject)(cl_mem /**< memobj */); extern PFN_clGetPlatformIDs clGetPlatformIDs; extern PFN_clGetDeviceIDs clGetDeviceIDs; +extern PFN_clGetDeviceInfo clGetDeviceInfo; extern PFN_clCreateContext clCreateContext; extern PFN_clCreateCommandQueue clCreateCommandQueue; extern PFN_clCreateBuffer clCreateBuffer;