diff --git a/lc0/src/neural/network_cudnn.cu b/lc0/src/neural/network_cudnn.cu
index a65f01631..3efa57e38 100644
--- a/lc0/src/neural/network_cudnn.cu
+++ b/lc0/src/neural/network_cudnn.cu
@@ -111,10 +111,12 @@ class BaseLayer {
   size_t GetOutputSize(int N) const { return bpe_ * N * C * H * W; }
 
   // input2 is optional (skip connection)
-  virtual void Eval(int N, float *output, const float *input,
-                    const float *input2, float *scratch, cudnnHandle_t cudnn,
+  virtual void Eval(int N, void *output, const void *input,
+                    const void *input2, void *scratch, cudnnHandle_t cudnn,
                     cublasHandle_t cublas) = 0;
 
+  static void enableFp16() { fp16_ = true; bpe_ = sizeof(half); }
+  static bool isFp16Enabled() { return fp16_; }
  protected:
   static bool fp16_;
   static size_t bpe_;  // size of each element
@@ -130,9 +132,9 @@ class ConvLayer : public BaseLayer {
   ConvLayer(BaseLayer *ip, int C, int H, int W, int size, int Cin,
             bool relu = false, bool bias = false);
   ~ConvLayer();
-  void LoadWeights(float *pfilter, float *pBias = nullptr);
-  void Eval(int N, float *output, const float *input, const float *input2,
-            float *scratch, cudnnHandle_t cudnn,
+  void LoadWeights(float *pfilter, float *pBias, void *scratch);
+  void Eval(int N, void *output, const void *input, const void *input2,
+            void *scratch, cudnnHandle_t cudnn,
             cublasHandle_t cublas) override;
 
  private:
@@ -141,8 +143,8 @@ class ConvLayer : public BaseLayer {
   const bool use_relu_;
   const bool use_bias_;
 
-  float *biases = nullptr;
-  float *weights = nullptr;
+  void *biases = nullptr;
+  void *weights = nullptr;
 
   cudnnFilterDescriptor_t filter_desc_;
   cudnnConvolutionDescriptor_t conv_desc_;
@@ -157,8 +159,8 @@ class SoftMaxLayer : public BaseLayer {
  public:
   SoftMaxLayer(BaseLayer *ip);
-  void Eval(int N, float *output, const float *input, const float *input2,
-            float *scratch, cudnnHandle_t cudnn,
+  void Eval(int N, void *output, const void *input, const void *input2,
+            void *scratch, cudnnHandle_t cudnn,
             cublasHandle_t cublas) override;
 
  private:
@@ -171,12 +173,15 @@ class BNLayer : public BaseLayer {
   ~BNLayer();
 
   void LoadWeights(float *cpuMeans, float *cpuVar);
-  void Eval(int N, float *output, const float *input, const float *input2,
-            float *scratch, cudnnHandle_t cudnn,
+  void Eval(int N, void *output, const void *input, const void *input2,
+            void *scratch, cudnnHandle_t cudnn,
             cublasHandle_t cublas) override;
 
  private:
   const bool use_relu_;
+
+  // always in float irrespective of fp16_
+  // not much point in converting these to fp16
   float *means_ = nullptr;
   float *variances_ = nullptr;
 };
@@ -187,17 +192,17 @@ class FCLayer : public BaseLayer {
             bool tanh = false);
   ~FCLayer();
 
-  void LoadWeights(float *cpuWeight, float *cpuBias);
-  void Eval(int N, float *output, const float *input, const float *input2,
-            float *scratch, cudnnHandle_t cudnn,
+  void LoadWeights(float *cpuWeight, float *cpuBias, void *scratch);
+  void Eval(int N, void *output, const void *input, const void *input2,
+            void *scratch, cudnnHandle_t cudnn,
             cublasHandle_t cublas) override;
 
  private:
   const bool use_bias_;
   const bool use_relu_;
   const bool use_tanh_;
-  float *weights_ = nullptr;
-  float *biases_ = nullptr;
+  void *weights_ = nullptr;
+  void *biases_ = nullptr;
 };
 
 // Need memory for 3 data buffers
@@ -224,12 +229,12 @@ __global__ void addVectors_kernel(T *c, T *a, T *b, int size, int asize,
                                   int bsize, bool relu, bool useTanh) {
   int i = threadIdx.x + blockDim.x * blockIdx.x;
   if (i < size) {
-    T aVal = 0;
-    T bVal = 0;
-    if (a) aVal = a[i % asize];
-    if (b) bVal = b[i % bsize];
+    float aVal = 0;
+    float bVal = 0;
+    if (a) aVal = (float) (a[i % asize]);
+    if (b) bVal = (float) (b[i % bsize]);
 
-    T cVal = aVal + bVal;
+    float cVal = aVal + bVal;
 
     if (relu && (cVal < 0)) cVal = 0;
 
@@ -242,7 +247,7 @@ __global__ void addVectors_kernel(T *c, T *a, T *b, int size, int asize,
       cVal = tanh(cVal);
     }
 
-    c[i] = cVal;
+    c[i] = (T) cVal;
   }
 }
 
@@ -259,13 +264,87 @@ void addVectors(T *c, T *a, T *b, int size, int asize, int bsize, bool relu,
   reportCUDAErrors(cudaGetLastError());
 }
 
-__global__ void batchNormForward_kernel(float *output, const float *input,
-                                        const float *skipInput, int N, int C,
+
+__device__ half readNCHW(float *inputTensor, int n, int c, int h, int w,
+                         int Nin, int Cin, int H, int W)
+{
+  if (n >= Nin || c >= Cin)
+    return 0;
+
+  int index;
+  index = n;
+  index *= Cin;
+  index += c;
+  index *= H;
+  index += h;
+  index *= W;
+  index += w;
+
+  return (half)(inputTensor[index]);
+}
+
+__global__ void fp32NCHWtofp16NHWC_kernel(half *outputTensor, float *inputTensor,
+                                          int Nin, int Cin, int Nout, int Cout,
+                                          int H, int W)
+{
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (tid >= Nout * Cout * H * W)
+    return;
+
+  int index = tid;
+
+  int c = (index % Cout);
+  index /= Cout;
+  int w = index % W;
+  index /= W;
+  int h = index % H;
+  index /= H;
+  int n = index;
+
+  outputTensor[tid] = readNCHW(inputTensor, n, c, h, w, Nin, Cin, H, W);
+}
+
+void fp32NCHWtofp16NHWC(half *outputTensor, float *inputTensor, int Nin, int Cin,
+                        int Nout, int Cout, int H, int W)
+{
+  size_t numElements = Nout * Cout * H * W;
+  const int blockSize = 256;
+  int blocks = divUp(numElements, blockSize);
+  fp32NCHWtofp16NHWC_kernel <<< blocks, blockSize >>> (outputTensor, inputTensor, Nin, Cin, Nout, Cout, H, W);
+}
+
+
+template <typename DstType, typename SrcType>
+__global__ void copyTypeConverted_kernel(DstType *op, SrcType *ip, int N)
+{
+  int tid = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (tid >= N)
+    return;
+
+  DstType el = (DstType) ip[tid];
+  op[tid] = el;
+}
+
+
+template <typename DstType, typename SrcType>
+void copyTypeConverted(DstType *op, SrcType *ip, int N)
+{
+  const int blockSize = 256;
+  int blocks = divUp(N, blockSize);
+  copyTypeConverted_kernel <<< blocks, blockSize >>> (op, ip, N);
+}
+
+template <typename T>
+__global__ void batchNormForward_kernel(T *output, const T *input,
+                                        const T *skipInput, int N, int C,
                                         int H, int W, const float *means,
                                         const float *varMultipliers, bool relu) {
   int index = threadIdx.x + blockDim.x * blockIdx.x;
-  int wIndex = (index / (H * W)) % C;
+
+  int wIndex = 0;
+  if (sizeof(T) == sizeof(float))
+    wIndex = (index / (H * W)) % C;  // NCHW for fp32
+  else
+    wIndex = index % C;  // NHWC for fp16
 
   float el = input[index];
   float mean = means[wIndex];
@@ -274,17 +353,16 @@ __global__ void batchNormForward_kernel(float *output, const float *input,
   el -= mean;
   el *= varMulti;
 
-  // TODO: figure out order of relu and skip connection
-  if (skipInput) el += skipInput[index];
+  if (skipInput) el += (float) skipInput[index];
 
   if (relu && (el < 0)) el = 0;
 
-  output[index] = el;
+  output[index] = (T) el;
 }
 
-// works only on NCHW tensors
 // each thread processes single element
-void batchNormForward(float *output, const float *input, const float *skipInput,
+template <typename T>
+void batchNormForward(T *output, const T *input, const T *skipInput,
                       int N, int C, int H, int W, float *means,
                       float *varMultipliers, bool relu) {
   int totalElements = N * C * H * W;
@@ -297,8 +375,8 @@ void batchNormForward(float *output, const float *input, const float *skipInput,
   reportCUDAErrors(cudaGetLastError());
 }
 
-__global__ void expandPlanes_kernel(float *output, const uint64_t *masks,
-                                    const float *values, int n) {
+__global__ void expandPlanes_kernel_Fp32_NCHW(float *output, const uint64_t *masks,
+                                              const float *values, int n) {
   // block size of 256, same mask/val for 64 consecutive threads
   constexpr int kNumShmemElments = 256 / 64;
@@ -329,17 +407,48 @@ __global__ void expandPlanes_kernel(float *output, const uint64_t *masks,
   }
   output[index] = op;
 }
-void expandPlanes(float *output, const uint64_t *masks, const float *values,
+
+void expandPlanes_Fp32_NCHW(float *output, const uint64_t *masks, const float *values,
                   int n) {
   int threads = n * 8 * 8;  // each thread writes a single element
   const int blockSize = 256;
   int blocks = divUp(threads, blockSize);
+  expandPlanes_kernel_Fp32_NCHW <<<blocks, blockSize>>>(output, masks, values, n);
+  reportCUDAErrors(cudaGetLastError());
+}
 
-  expandPlanes_kernel<<<blocks, blockSize>>>(output, masks, values, n);
+// TODO: can optimize using shared memory if this becomes a bottleneck
+__global__ void expandPlanes_kernel_Fp16_NHWC(half *output, const uint64_t *masks,
+                                              const float *values, int n) {
 
-  reportCUDAErrors(cudaGetLastError());
+  const int index = threadIdx.x + blockDim.x * blockIdx.x;
+  if (index >= n * 8 * 8) return;
+
+  const int planeIndex = index % kInputPlanes;
+  const int boardIndex = index / (kInputPlanes * 8 * 8);
+  const int sqIndex = (index / kInputPlanes) & 0x3F;
+
+  uint64_t mask = masks[boardIndex * kInputPlanes + planeIndex];
+
+  half op = 0;
+  bool set = !!(mask & (1ull << sqIndex));
+  if (set) {
+    float val = values[boardIndex * kInputPlanes + planeIndex];
+    op = (half)val;
+  }
+  output[index] = op;
+}
+
+void expandPlanes_Fp16_NHWC(half *output, const uint64_t *masks, const float *values,
+                            int n) {
+  int threads = n * 8 * 8;  // each thread writes a single element
+  const int blockSize = 256;
+  int blocks = divUp(threads, blockSize);
+  expandPlanes_kernel_Fp16_NHWC <<<blocks, blockSize>>>(output, masks, values, n);
+  reportCUDAErrors(cudaGetLastError());
 }
+
 
 BaseLayer::BaseLayer(int c, int h, int w, BaseLayer *ip)
     : C(c), H(h), W(w), input_(ip) {}
@@ -348,8 +457,8 @@ SoftMaxLayer::SoftMaxLayer(BaseLayer *ip)
   cudnnCreateTensorDescriptor(&out_tensor_desc_);
 }
 
-void SoftMaxLayer::Eval(int N, float *output, const float *input,
-                        const float *input2, float *scratch,
+void SoftMaxLayer::Eval(int N, void *output, const void *input,
+                        const void *input2, void *scratch,
                         cudnnHandle_t cudnn, cublasHandle_t cublas) {
   float alpha = 1.0f, beta = 0.0f;
@@ -388,7 +497,7 @@ ConvLayer::ConvLayer(BaseLayer *ip, int C, int H, int W, int filter, int Cin,
   cudnnSetFilter4dDescriptor(
       filter_desc_, fp16_ ? CUDNN_DATA_HALF : CUDNN_DATA_FLOAT,
       fp16_ ? CUDNN_TENSOR_NHWC
-            : CUDNN_TENSOR_NCHW,  // TODO: support fp16 evaluation
+            : CUDNN_TENSOR_NCHW,
       GetC(), Cin, filter_size_, filter_size_);
 
   reportCUDNNErrors(cudnnSetTensor4dDescriptor(
@@ -398,13 +507,16 @@ ConvLayer::ConvLayer(BaseLayer *ip, int C, int H, int W, int filter, int Cin,
   int padding = filter_size_ / 2;
   const bool crossCorr = 1;
 
-  cudnnSetConvolution2dDescriptor(
+  reportCUDNNErrors(cudnnSetConvolution2dDescriptor(
       conv_desc_, padding, padding, 1, 1, 1, 1,
       crossCorr ? CUDNN_CROSS_CORRELATION : CUDNN_CONVOLUTION,
-      fp16_ ? CUDNN_DATA_HALF : CUDNN_DATA_FLOAT);
+      fp16_ ? CUDNN_DATA_HALF : CUDNN_DATA_FLOAT));
+
+  if (fp16_)
+    reportCUDNNErrors(cudnnSetConvolutionMathType(conv_desc_, CUDNN_TENSOR_OP_MATH));
 
   // TODO: dynamic selection of algorithm!
-  if (C > 32) {
+  if ((C > 32) && (!fp16_)) {
     convAlgo = CUDNN_CONVOLUTION_FWD_ALGO_WINOGRAD_NONFUSED;
   } else {
     convAlgo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_PRECOMP_GEMM;
@@ -419,22 +531,42 @@ ConvLayer::ConvLayer(BaseLayer *ip, int C, int H, int W, int filter, int Cin,
   }
 }
 
-void ConvLayer::LoadWeights(float *pfilter, float *pBias) {
-  size_t weightSize = bpe_ * c_input_ * C * filter_size_ * filter_size_;
-  reportCUDAErrors(
-      cudaMemcpyAsync(weights, pfilter, weightSize, cudaMemcpyHostToDevice));
+void ConvLayer::LoadWeights(float *pfilter, float *pBias, void *scratch) {
+  size_t weightSize = sizeof(float) * c_input_ * C * filter_size_ * filter_size_;
+  size_t biasSize = sizeof(float) * C;
+  if (fp16_) {
+    // also need to convert from fp32 NCHW to fp16 NHWC
+    // first copy from CPU memory to scratch space in GPU memory
+    // and then do the type / layout conversion using a kernel
+    assert(scratch);
+    reportCUDAErrors(
+        cudaMemcpyAsync(scratch, pfilter, weightSize, cudaMemcpyHostToDevice));
+    fp32NCHWtofp16NHWC((half *)weights, (float*)scratch, C, c_input_, C, c_input_, filter_size_, filter_size_);
 
-  size_t biasSize = bpe_ * C;
-  if (pBias) {
+    if (pBias) {
+      reportCUDAErrors(
+          cudaMemcpyAsync(scratch, pBias, biasSize, cudaMemcpyHostToDevice));
+
+      copyTypeConverted((half*)biases, (float *)scratch, C);
+    }
+  }
+  else
+  {
     reportCUDAErrors(
-        cudaMemcpyAsync(biases, pBias, biasSize, cudaMemcpyHostToDevice));
-  } else {
-    reportCUDAErrors(cudaMemset(biases, biasSize, 0));
+        cudaMemcpyAsync(weights, pfilter, weightSize, cudaMemcpyHostToDevice));
+
+    if (pBias) {
+      reportCUDAErrors(
+          cudaMemcpyAsync(biases, pBias, biasSize, cudaMemcpyHostToDevice));
+    }
+    else {
+      reportCUDAErrors(cudaMemset(biases, biasSize, 0));
+    }
   }
 }
 
-void ConvLayer::Eval(int N, float *output, const float *input,
-                     const float *input2, float *scratch, cudnnHandle_t cudnn,
+void ConvLayer::Eval(int N, void *output, const void *input,
                     const void *input2, void *scratch, cudnnHandle_t cudnn,
                      cublasHandle_t cublas) {
   reportCUDNNErrors(cudnnSetTensor4dDescriptor(
       out_tensor_desc_, fp16_ ? CUDNN_TENSOR_NHWC : CUDNN_TENSOR_NCHW,
@@ -474,25 +606,29 @@ ConvLayer::~ConvLayer() {
 
 BNLayer::BNLayer(BaseLayer *ip, bool relu)
     : BaseLayer(ip->GetC(), ip->GetH(), ip->GetW(), ip), use_relu_(relu) {
-  size_t weightSize = bpe_ * C;
+  size_t weightSize = sizeof(float) * C;
   reportCUDAErrors(cudaMalloc(&means_, weightSize));
   reportCUDAErrors(cudaMalloc(&variances_, weightSize));
 }
 
 void BNLayer::LoadWeights(float *cpuMeans, float *cpuVar) {
-  size_t weightSize = bpe_ * C;
+  size_t weightSize = sizeof(float) * C;
   reportCUDAErrors(
       cudaMemcpyAsync(means_, cpuMeans, weightSize, cudaMemcpyHostToDevice));
   reportCUDAErrors(
       cudaMemcpyAsync(variances_, cpuVar, weightSize, cudaMemcpyHostToDevice));
 }
 
-void BNLayer::Eval(int N, float *output, const float *input,
-                   const float *input2, float *scratch, cudnnHandle_t cudnn,
+void BNLayer::Eval(int N, void *output, const void *input,
+                   const void *input2, void *scratch, cudnnHandle_t cudnn,
                    cublasHandle_t cublas) {
-  batchNormForward(output, input, input2, N, C, H, W, means_, variances_,
-                   use_relu_);
+  if (fp16_)
+    batchNormForward((half*)output, (const half*)input, (const half*) input2,
+                     N, C, H, W, means_, variances_, use_relu_);
+  else
+    batchNormForward((float*)output, (const float*)input, (const float*) input2,
+                     N, C, H, W, means_, variances_, use_relu_);
 }
 
 BNLayer::~BNLayer() {
@@ -516,37 +652,169 @@ FCLayer::FCLayer(BaseLayer *ip, int C, int H, int W, bool relu, bool bias,
   }
 }
 
-void FCLayer::LoadWeights(float *cpuWeight, float *cpuBias) {
-  size_t weightSize =
-      bpe_ * C * H * W * input_->GetC() * input_->GetH() * input_->GetW();
+void FCLayer::LoadWeights(float *cpuWeight, float *cpuBias, void *scratch) {
+  size_t numWeights = C * H * W * input_->GetC() * input_->GetH() * input_->GetW();
+  size_t weightSize = sizeof(float) * numWeights;
+  size_t numBiases = C * H * W;
+  size_t biasSize = sizeof(float) * numBiases;
 
-  reportCUDAErrors(
-      cudaMemcpyAsync(weights_, cpuWeight, weightSize, cudaMemcpyHostToDevice));
-  if (use_bias_) {
-    size_t biasSize = bpe_ * C * H * W;
+  if (fp16_) {
+    // also need to convert from fp32 to fp16
+    assert(scratch);
+    reportCUDAErrors(
+        cudaMemcpyAsync(scratch, cpuWeight, weightSize, cudaMemcpyHostToDevice));
+
+    //copyTypeConverted((half*)weights_, (float *)scratch, numWeights);
+    fp32NCHWtofp16NHWC((half *)weights_, (float*)scratch, numBiases, input_->GetC(), numBiases, input_->GetC(), input_->GetH(), input_->GetW());
+
+    if (cpuBias) {
+      reportCUDAErrors(
+          cudaMemcpyAsync(scratch, cpuBias, biasSize, cudaMemcpyHostToDevice));
+      copyTypeConverted((half*)biases_, (float *)scratch, numBiases);
+    }
+  }
+  else
+  {
     reportCUDAErrors(
-        cudaMemcpyAsync(biases_, cpuBias, biasSize, cudaMemcpyHostToDevice));
+        cudaMemcpyAsync(weights_, cpuWeight, weightSize, cudaMemcpyHostToDevice));
+    if (use_bias_) {
+      reportCUDAErrors(
+          cudaMemcpyAsync(biases_, cpuBias, biasSize, cudaMemcpyHostToDevice));
+    }
   }
 }
 
-void FCLayer::Eval(int N, float *outputTensor, const float *inputTensor,
-                   const float *input2, float *scratch, cudnnHandle_t cudnn,
+// taken from: https://devtalk.nvidia.com/default/topic/883897/error-when-trying-to-use-half-fp16-/
+/*
+Copyright (c) 2015, Norbert Juffa
+All rights reserved.
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions
+are met:
+
+1. Redistributions of source code must retain the above copyright
+notice, this list of conditions and the following disclaimer.
+
+2. Redistributions in binary form must reproduce the above copyright
+notice, this list of conditions and the following disclaimer in the
+documentation and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+*/
+
+half uint16_as_fp16(uint16_t a)
+{
+  half res;
+#if defined (__cplusplus)
+  memcpy(&res, &a, sizeof(res));
+#else /* __cplusplus */
+  volatile union {
+    __fp16 f;
+    uint16_t i;
+  } cvt;
+  cvt.i = a;
+  res = cvt.f;
+#endif /* __cplusplus */
+  return res;
+}
+
+uint32_t fp32_as_uint32(float a)
+{
+  uint32_t res;
+#if defined (__cplusplus)
+  memcpy(&res, &a, sizeof(res));
+#else /* __cplusplus */
+  volatile union {
+    float f;
+    uint32_t i;
+  } cvt;
+  cvt.f = a;
+  res = cvt.i;
+#endif /* __cplusplus */
+  return res;
+}
+
+/* host version of device function __float2half_rn() */
+half float2half_rn(float a)
+{
+  uint32_t ia = fp32_as_uint32(a);
+  uint16_t ir;
+
+  ir = (ia >> 16) & 0x8000;
+  if ((ia & 0x7f800000) == 0x7f800000) {
+    if ((ia & 0x7fffffff) == 0x7f800000) {
+      ir |= 0x7c00; /* infinity */
+    }
+    else {
+      ir = 0x7fff; /* canonical NaN */
+    }
+  }
+  else if ((ia & 0x7f800000) >= 0x33000000) {
+    int shift = (int)((ia >> 23) & 0xff) - 127;
+    if (shift > 15) {
+      ir |= 0x7c00; /* infinity */
+    }
+    else {
+      ia = (ia & 0x007fffff) | 0x00800000; /* extract mantissa */
+      if (shift < -14) { /* denormal */
+        ir |= ia >> (-1 - shift);
+        ia = ia << (32 - (-1 - shift));
+      }
+      else { /* normal */
+        ir |= ia >> (24 - 11);
+        ia = ia << (32 - (24 - 11));
+        ir = ir + ((14 + shift) << 10);
+      }
+      /* IEEE-754 round to nearest of even */
+      if ((ia > 0x80000000) || ((ia == 0x80000000) && (ir & 1))) {
+        ir++;
+      }
+    }
+  }
+  return uint16_as_fp16(ir);
+}
+
+void FCLayer::Eval(int N, void *outputTensor, const void *inputTensor,
+                   const void *input2, void *scratch, cudnnHandle_t cudnn,
                    cublasHandle_t cublas) {
-  float alpha = 1.0f, beta = 0.0f;
   int numOutputs = C * H * W;
   int numInputs = input_->GetC() * input_->GetH() * input_->GetW();
 
   if (fp16_) {
-    // TODO: implement this!
-    assert(0);
+    half alpha = float2half_rn(1.0f), beta = float2half_rn(0.0f);
+    reportCUBLASErrors(cublasHgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N,
+                                   numOutputs, N, numInputs,
+                                   &alpha,
+                                   (half *)weights_, numInputs,
+                                   (half *)inputTensor, numInputs,
+                                   &beta,
+                                   (half *)outputTensor, numOutputs));
+
+    if (use_bias_ || use_relu_ || use_tanh_) {
+      addVectors((half*)outputTensor, (half*)biases_, (half*)outputTensor, numOutputs * N,
+                 numOutputs, numOutputs * N, use_relu_, use_tanh_);
+    }
+  } else {
+    float alpha = 1.0f, beta = 0.0f;
     reportCUBLASErrors(cublasSgemm(cublas, CUBLAS_OP_T, CUBLAS_OP_N, numOutputs,
-                                   N, numInputs, &alpha, weights_, numInputs,
-                                   inputTensor, numInputs, &beta, outputTensor,
+                                   N, numInputs, &alpha, (float*)weights_, numInputs,
+                                   (float*)inputTensor, numInputs, &beta, (float*)outputTensor,
                                    numOutputs));
 
     if (use_bias_ || use_relu_ || use_tanh_) {
-      addVectors(outputTensor, biases_, outputTensor, numOutputs * N,
+      addVectors((float*)outputTensor, (float *)biases_, (float*)outputTensor, numOutputs * N,
                  numOutputs, numOutputs * N, use_relu_, use_tanh_);
     }
   }
@@ -648,6 +916,7 @@ class CudnnNetwork : public Network {
  public:
   CudnnNetwork(Weights weights, const OptionsDict &options) {
     gpuId_ = options.GetOrDefault<int>("gpu", 0);
+    int tryFp16 = options.GetOrDefault<int>("fp16", 0);
 
     int totalGPUs;
     reportCUDAErrors(cudaGetDeviceCount(&totalGPUs));
@@ -661,6 +930,24 @@ class CudnnNetwork : public Network {
     reportCUDNNErrors(cudnnCreate(&cudnn_));
     reportCUBLASErrors(cublasCreate(&cublas_));
 
+    if (tryFp16) {
+      // check if the GPU support fp16 (Volta+)
+      // enable fp16 only if all devices on which we are trying to run
+      // have fp16 support (TODO: can fix this limitation if needed)
+      cudaDeviceProp deviceProp = {};
+      cudaGetDeviceProperties(&deviceProp, gpuId_);
+      if (deviceProp.major >= 7) {
+        BaseLayer::enableFp16();
+        reportCUBLASErrors(cublasSetMathMode(cublas_, CUBLAS_TENSOR_OP_MATH));
+      } else {
+        throw Exception("Your GPU doesn't support FP16");
+      }
+    } else {
+      if (BaseLayer::isFp16Enabled()) {
+        throw Exception("Different fp16 setting for different GPUs not yet supported");
+      }
+    }
+
     const int numInputPlanes = kInputPlanes;
     const int numFilters = weights.input.biases.size();
 
@@ -675,13 +962,19 @@ class CudnnNetwork : public Network {
     processConvBlock(weights.policy);
     processConvBlock(weights.value);
 
-    // 1. build the network, and copy the weights to GPU memory
+    // 1. allocate scratch space (used internally by cudnn to run convolutions,
+    //    and also for format/layout conversion for weights)
+    reportCUDAErrors(cudaMalloc(&scratch_mem_, kCudaScratchSize));
+
+
+    // 2. build the network, and copy the weights to GPU memory
     // input
     {
       auto inputConv = std::make_unique<ConvLayer>(nullptr, numFilters, 8, 8, 3,
                                                    numInputPlanes, true, true);
       inputConv->LoadWeights(&weights.input.weights[0],
-                             &weights.input.biases[0]);
+                             &weights.input.biases[0],
+                             scratch_mem_);
       network_.emplace_back(std::move(inputConv));
     }
 
@@ -690,13 +983,15 @@ class CudnnNetwork : public Network {
       auto conv1 = std::make_unique<ConvLayer>(getLastLayer(), numFilters, 8, 8,
                                                3, numFilters, true, true);
       conv1->LoadWeights(&weights.residual[block].conv1.weights[0],
-                         &weights.residual[block].conv1.biases[0]);
+                         &weights.residual[block].conv1.biases[0],
+                         scratch_mem_);
       network_.emplace_back(std::move(conv1));
 
       auto conv2 = std::make_unique<ConvLayer>(getLastLayer(), numFilters, 8, 8,
                                                3, numFilters, true, true);
       conv2->LoadWeights(&weights.residual[block].conv2.weights[0],
-                         &weights.residual[block].conv2.biases[0]);
+                         &weights.residual[block].conv2.biases[0],
+                         scratch_mem_);
       network_.emplace_back(std::move(conv2));
     }
 
@@ -706,7 +1001,7 @@ class CudnnNetwork : public Network {
     {
       auto convPol = std::make_unique<ConvLayer>(
          resi_last_, weights.policy.bn_means.size(), 8, 8, 1, numFilters);
-      convPol->LoadWeights(&weights.policy.weights[0]);
+      convPol->LoadWeights(&weights.policy.weights[0], nullptr, scratch_mem_);
       network_.emplace_back(std::move(convPol));
 
       auto BNPol = std::make_unique<BNLayer>(getLastLayer(), true);
@@ -716,7 +1011,7 @@ class CudnnNetwork : public Network {
 
       auto FCPol = std::make_unique<FCLayer>(
          getLastLayer(), weights.ip_pol_b.size(), 1, 1, false, true);
-      FCPol->LoadWeights(&weights.ip_pol_w[0], &weights.ip_pol_b[0]);
+      FCPol->LoadWeights(&weights.ip_pol_w[0], &weights.ip_pol_b[0], scratch_mem_);
       network_.emplace_back(std::move(FCPol));
 
       auto softmaxPol = std::make_unique<SoftMaxLayer>(getLastLayer());
@@ -728,7 +1023,7 @@ class CudnnNetwork : public Network {
     {
       auto convVal = std::make_unique<ConvLayer>(
          resi_last_, weights.value.bn_means.size(), 8, 8, 1, numFilters);
-      convVal->LoadWeights(&weights.value.weights[0]);
+      convVal->LoadWeights(&weights.value.weights[0], nullptr, scratch_mem_);
       network_.emplace_back(std::move(convVal));
 
       auto BNVal = std::make_unique<BNLayer>(getLastLayer(), true);
@@ -738,17 +1033,17 @@ class CudnnNetwork : public Network {
 
       auto FCVal1 = std::make_unique<FCLayer>(
          getLastLayer(), weights.ip1_val_b.size(), 1, 1, true, true);
-      FCVal1->LoadWeights(&weights.ip1_val_w[0], &weights.ip1_val_b[0]);
+      FCVal1->LoadWeights(&weights.ip1_val_w[0], &weights.ip1_val_b[0], scratch_mem_);
       network_.emplace_back(std::move(FCVal1));
 
       auto FCVal2 = std::make_unique<FCLayer>(getLastLayer(), 1, 1, 1, false,
                                               true, true);
-      FCVal2->LoadWeights(&weights.ip2_val_w[0], &weights.ip2_val_b[0]);
+      FCVal2->LoadWeights(&weights.ip2_val_w[0], &weights.ip2_val_b[0], scratch_mem_);
      network_.emplace_back(std::move(FCVal2));
     }
     value_out_ = getLastLayer();
 
-    // 2. allocate GPU memory for running the network
+    // 3. allocate GPU memory for running the network
     //    - three buffers of max size are enough (one to hold input, second to
     //      hold output and third to hold skip connection's input)
     size_t maxSize = resi_last_->GetOutputSize(kMaxBatchSize);
@@ -759,9 +1054,6 @@ class CudnnNetwork : public Network {
 
     // printf("Allocated %d bytes of GPU memory to run the network\n", 3 *
     // maxSize);
-
-    // 3. allocate scratch space (used internally by cudnn to run convolutions)
-    reportCUDAErrors(cudaMalloc(&scratch_mem_, kCudaScratchSize));
   }
 
   void forwardEval(InputsOutputs *io, int batchSize) {
@@ -774,8 +1066,14 @@ class CudnnNetwork : public Network {
     // expand packed planes to full planes
     uint64_t *ipDataMasks = io->input_masks_mem_gpu_;
     float *ipDataValues = io->input_val_mem_gpu_;
-    expandPlanes(tensor_mem_[0], ipDataMasks, ipDataValues,
-                 batchSize * kInputPlanes);
+
+    if (BaseLayer::isFp16Enabled()) {
+      expandPlanes_Fp16_NHWC((half*)(tensor_mem_[0]), ipDataMasks, ipDataValues,
+                             batchSize * kInputPlanes);
+    } else {
+      expandPlanes_Fp32_NCHW((float*)(tensor_mem_[0]), ipDataMasks, ipDataValues,
+                             batchSize * kInputPlanes);
+    }
 
     float *opPol = io->op_policy_mem_gpu_;
     float *opVal = io->op_value_mem_gpu_;
@@ -789,6 +1087,7 @@ class CudnnNetwork : public Network {
     for (int block = 0; block < numBlocks_; block++) {
       network_[l++]->Eval(batchSize, tensor_mem_[0], tensor_mem_[2], nullptr,
                           scratch_mem_, cudnn_, cublas_);  // conv1
+
       network_[l++]->Eval(batchSize, tensor_mem_[2], tensor_mem_[0],
                           tensor_mem_[2], scratch_mem_, cudnn_,
                           cublas_);  // conv2
@@ -801,9 +1100,17 @@ class CudnnNetwork : public Network {
                         scratch_mem_, cudnn_, cublas_);  // pol BN
     network_[l++]->Eval(batchSize, tensor_mem_[0], tensor_mem_[1], nullptr,
                         scratch_mem_, cudnn_, cublas_);  // pol FC
-    network_[l++]->Eval(batchSize, opPol, tensor_mem_[0], nullptr, scratch_mem_,
-                        cudnn_,
-                        cublas_);  // pol softmax  // POLICY
+    if (BaseLayer::isFp16Enabled()) {
+      // TODO: consider softmax layer that writes directly to fp32
+      network_[l++]->Eval(batchSize, tensor_mem_[1], tensor_mem_[0], nullptr, scratch_mem_,
+                          cudnn_,
+                          cublas_);  // pol softmax
+      copyTypeConverted(opPol, (half *)(tensor_mem_[1]), batchSize * kNumOutputPolicy);  // POLICY
+    } else {
+      network_[l++]->Eval(batchSize, opPol, tensor_mem_[0], nullptr, scratch_mem_,
                          cudnn_,
+                          cublas_);  // pol softmax  // POLICY
+    }
 
     // value head
     network_[l++]->Eval(batchSize, tensor_mem_[0], tensor_mem_[2], nullptr,
@@ -812,10 +1119,17 @@ class CudnnNetwork : public Network {
                         scratch_mem_, cudnn_, cublas_);  // value BN
     network_[l++]->Eval(batchSize, tensor_mem_[0], tensor_mem_[2], nullptr,
                         scratch_mem_, cudnn_, cublas_);  // value FC1
-    network_[l++]->Eval(batchSize, opVal, tensor_mem_[0], nullptr, scratch_mem_,
-                        cudnn_,
-                        cublas_);  // value FC2  // VALUE
-
+    if (BaseLayer::isFp16Enabled()) {
+      // TODO: consider fusing the bias-add of FC2 with format conversion
+      network_[l++]->Eval(batchSize, tensor_mem_[2], tensor_mem_[0], nullptr, scratch_mem_,
+                          cudnn_,
+                          cublas_);  // value FC2
+      copyTypeConverted(opVal, (half *)(tensor_mem_[2]), batchSize);  // VALUE
+    } else {
+      network_[l++]->Eval(batchSize, opVal, tensor_mem_[0], nullptr, scratch_mem_,
+                          cudnn_,
+                          cublas_);  // value FC2  // VALUE
+    }
     reportCUDAErrors(cudaDeviceSynchronize());
 
 #if DEBUG_RAW_NPS == 1
@@ -897,8 +1211,8 @@ class CudnnNetwork : public Network {
   BaseLayer *policy_out_;
   BaseLayer *value_out_;
 
-  float *tensor_mem_[3];
-  float *scratch_mem_;
+  void *tensor_mem_[3];
+  void *scratch_mem_;
 
   mutable std::mutex inputs_outputs_lock_;
   std::list<std::unique_ptr<InputsOutputs>> free_inputs_outputs_;
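
Note (appended after the patch, not part of it): the piece most worth double-checking above is the fp32 NCHW to fp16 NHWC weight conversion. The sketch below is a hypothetical host-side reference of the same index arithmetic used by fp32NCHWtofp16NHWC_kernel and readNCHW, useful as a CPU oracle when unit-testing the kernel. The function name convertNCHWtoNHWC_ref is illustrative only, and the final float-to-half cast is deliberately omitted so the reference stays in plain float.

// Hypothetical CPU reference (assumption, not code from the patch) for the
// index mapping done by fp32NCHWtofp16NHWC_kernel: output is NHWC with the
// channel fastest-varying, input is NCHW, and (n, c) positions outside the
// input tensor read as zero, matching readNCHW's bounds check.
#include <vector>

std::vector<float> convertNCHWtoNHWC_ref(const std::vector<float> &in, int Nin,
                                         int Cin, int Nout, int Cout, int H,
                                         int W) {
  std::vector<float> out((size_t)Nout * Cout * H * W, 0.0f);
  for (int n = 0; n < Nout; n++)
    for (int h = 0; h < H; h++)
      for (int w = 0; w < W; w++)
        for (int c = 0; c < Cout; c++) {
          float val = 0.0f;
          if (n < Nin && c < Cin)
            val = in[(((size_t)n * Cin + c) * H + h) * W + w];  // NCHW read
          out[(((size_t)n * H + h) * W + w) * Cout + c] = val;  // NHWC write
        }
  return out;
}

Comparing this reference against the kernel output, after copying the fp16 result back to the host and widening it to float, gives a cheap check of the layout change for both call sites in the patch: the convolution filters (Nin = Nout = C output channels, Cin = Cout = c_input_, H = W = filter_size_) and the FC weights (Nin = Nout = numBiases, Cin = Cout = input channels).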