diff --git a/README.md b/README.md index ed7bcca..5001fb6 100644 --- a/README.md +++ b/README.md @@ -844,3 +844,49 @@ To compile and run the benchmark, please use the following command: make sparse EIGEN_PATH= bin/sparse_bench ``` + +# AMD Benchmarks + +## Prerequisites +* A ROCm-enabled platform, more info [here](https://rocm.github.io/install.html). +* [MIOpen](https://github.com/ROCmSoftwarePlatform/MIOpen) - HIP backend of MIOpen is required. +* [rocBLAS](https://github.com/ROCmSoftwarePlatform/rocBLAS) + +At present only `fp32 train` benchmarks are enabled. + +## Compiling + +The `Makefile` in `code/amd` is for an AMD `gfx900` GPU. To benchmark other generations, please modify the `Makefile` accordingly. + +Setting your environment variables before compiling/running: + +``` +export PATH=PATH_TO_ROCM/bin:$PATH +export CPATH=PATH_TO_MIOPEN/include:$CPATH +export LIBRARY_PATH=PATH_TO_MIOPEN/lib:$LIBRARY_PATH +export LD_LIBRARY_PATH=PATH_TO_MIOPEN/lib:PATH_TO_MIOPENGEMM/lib:$LD_LIBRARY_PATH +``` + +To compile the convolution, RNNs and GEMM benchmarks, run: + +``` +make conv rnn gemm +``` + +## Running the Benchmarks +After successful compilation, the executables will be generated in the `bin` directory. + +To benchmark convolutions: +``` +bin/conv_bench +``` + +To benchmark RNN: +``` +bin/rnn_bench +``` + +To benchmark GEMM: +``` +bin/gemm_bench +``` diff --git a/code/amd/Makefile b/code/amd/Makefile new file mode 100644 index 0000000..6860a55 --- /dev/null +++ b/code/amd/Makefile @@ -0,0 +1,35 @@ +SOURCE_DIR?=. 
+BIN_DIR?=bin +MKDIR=mkdir -p + +#hipcc +HIPCC=/opt/rocm/bin/hipcc + +#BLAS +ROCBLAS_LIB=rocblas + +#CONV +MIOPEN_LIB?=MIOpen + +#DeepBench +DEEPBENCH_INC=${SOURCE_DIR}/../kernels + +all: conv rnn gemm + +#OPT=-g -O0 -fsanitize=undefined -fno-omit-frame-pointer +OPT=-O3 + +conv: + $(MKDIR) $(BIN_DIR) + $(HIPCC) ${SOURCE_DIR}/conv_bench_rocm.cpp -o $(BIN_DIR)/conv_bench -I$(DEEPBENCH_INC) -l$(MIOPEN_LIB) $(OPT) -std=c++11 --amdgpu-target=gfx900 + +rnn: + $(MKDIR) $(BIN_DIR) + $(HIPCC) ${SOURCE_DIR}/rnn_bench_rocm.cpp -o $(BIN_DIR)/rnn_bench -I$(DEEPBENCH_INC) -l$(MIOPEN_LIB) $(OPT) -std=c++11 --amdgpu-target=gfx900 + +gemm: + $(MKDIR) $(BIN_DIR) + $(HIPCC) ${SOURCE_DIR}/gemm_bench.cpp -o $(BIN_DIR)/gemm_bench -I$(DEEPBENCH_INC) -l$(ROCBLAS_LIB) $(OPT) -std=c++11 --amdgpu-target=gfx900 + +clean: + rm -rf $(BIN_DIR) diff --git a/code/amd/conv_bench_rocm.cpp b/code/amd/conv_bench_rocm.cpp new file mode 100644 index 0000000..be92829 --- /dev/null +++ b/code/amd/conv_bench_rocm.cpp @@ -0,0 +1,419 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include "tensor.h" +#include "miopen_helper.h" +#include "conv_problems.h" + +template +class miopenCNN { + TensorDescriptor4d x_desc_; + TensorDescriptor4d h_desc_; + + FilterDescriptor4d w_desc_; + + std::vector output_dims_; + int num_repeats_; + + size_t fwd_workspace_size_; + size_t bwd_inputs_workspace_size_; + size_t bwd_params_workspace_size_; + + Tensor fwd_workspace_; + Tensor bwd_inputs_workspace_; + Tensor bwd_params_workspace_; + + Tensor h; + + miopenConvFwdAlgorithm_t fwd_algo_; + miopenConvBwdDataAlgorithm_t bwd_inputs_algo_; + miopenConvBwdWeightsAlgorithm_t bwd_params_algo_; + + const float alpha_ = 1.f; + const float beta_ = 0.f; + + ConvolutionDescriptor conv_desc_; + MIOpenHandle miopen_handle_; + +public: + + miopenCNN(int _w, int _h, int c, int n, int k, int r, int s, + int pad_w, int pad_h, int wstride, int hstride, Tensor x, Tensor w) + : + miopen_handle_(), + x_desc_(n, c, 
_h, _w), + w_desc_(k, c, r, s), + conv_desc_(pad_h, pad_w, hstride, wstride) + { + int out_h, out_w, out_c, out_n; + + // Get output dimensions + CHECK_MIOPEN_ERROR(miopenGetConvolutionForwardOutputDim(conv_desc_.desc(), + x_desc_.desc(), + w_desc_.desc(), + &out_n, + &out_c, + &out_h, + &out_w)); + + h_desc_ = TensorDescriptor4d(out_n, out_c, out_h, out_w); + + output_dims_ = {out_w, out_h, out_c, out_n}; + + h = zeros(output_dims_); + + + // Set fwd workspace size + CHECK_MIOPEN_ERROR(miopenConvolutionForwardGetWorkSpaceSize( + miopen_handle_.handle(), + w_desc_.desc(), + x_desc_.desc(), + conv_desc_.desc(), + h_desc_.desc(), + &fwd_workspace_size_)); + + std::vector u = std::vector{static_cast(fwd_workspace_size_ / sizeof(float)), 1}; + + fwd_workspace_ = zeros(u); + + const int requestAlgoCount = 1; + int returnedAlgoCount; + miopenConvAlgoPerf_t perfResults; + + CHECK_MIOPEN_ERROR(miopenFindConvolutionForwardAlgorithm( + miopen_handle_.handle(), + x_desc_.desc(), + x.begin(), + w_desc_.desc(), + w.begin(), + conv_desc_.desc(), + h_desc_.desc(), + h.begin(), + requestAlgoCount, + &returnedAlgoCount, + &perfResults, + fwd_workspace_.begin(), + fwd_workspace_size_, + false + )); + + fwd_algo_ = perfResults.fwd_algo; + + + CHECK_MIOPEN_ERROR(miopenConvolutionBackwardWeightsGetWorkSpaceSize( + miopen_handle_.handle(), + h_desc_.desc(), + x_desc_.desc(), + conv_desc_.desc(), + w_desc_.desc(), + &bwd_params_workspace_size_)); + u = std::vector{static_cast(bwd_params_workspace_size_ / sizeof(float)), 1}; + bwd_params_workspace_ = zeros(u); + + CHECK_MIOPEN_ERROR(miopenFindConvolutionBackwardWeightsAlgorithm( + miopen_handle_.handle(), + h_desc_.desc(), + h.begin(), + x_desc_.desc(), + x.begin(), + conv_desc_.desc(), + w_desc_.desc(), + w.begin(), + requestAlgoCount, + &returnedAlgoCount, + &perfResults, + bwd_params_workspace_.begin(), + bwd_params_workspace_size_, + false + )); + + bwd_params_algo_ = perfResults.bwd_weights_algo; + + 
CHECK_MIOPEN_ERROR(miopenConvolutionBackwardDataGetWorkSpaceSize( + miopen_handle_.handle(), + h_desc_.desc(), + w_desc_.desc(), + conv_desc_.desc(), + x_desc_.desc(), + &bwd_inputs_workspace_size_)); + + u = std::vector{static_cast(bwd_inputs_workspace_size_ / sizeof(float)), 1}; + bwd_inputs_workspace_ = zeros(u); + + CHECK_MIOPEN_ERROR(miopenFindConvolutionBackwardDataAlgorithm( + miopen_handle_.handle(), + h_desc_.desc(), + h.begin(), + w_desc_.desc(), + w.begin(), + conv_desc_.desc(), + x_desc_.desc(), + x.begin(), + requestAlgoCount, + &returnedAlgoCount, + &perfResults, + bwd_inputs_workspace_.begin(), + bwd_inputs_workspace_size_, + false + )); + + bwd_inputs_algo_ = perfResults.bwd_data_algo; + + } + + Tensor getOutputTensor(){ return h; } + + std::vector get_output_dims() { return output_dims_; } + + std::string get_fwd_algo_string() { + if (fwd_algo_ == miopenConvolutionFwdAlgoGEMM) + return " ConvolutionFwdAlgoGEMM"; + else if (fwd_algo_ == miopenConvolutionFwdAlgoDirect) + return " ConvolutionFwdAlgoDirect"; + else if (fwd_algo_ == miopenConvolutionFwdAlgoFFT) + return " ConvolutionFwdAlgoFFT"; + else if (fwd_algo_ == miopenConvolutionFwdAlgoWinograd) + return " ConvolutionFwdAlgoWinograd"; + else { + std::stringstream ss; + ss << "Illegal algorithm passed to get_fwd_algo_string. Algo: " << fwd_algo_ << std::endl; + throw std::runtime_error(ss.str()); + } + } + + + void forward(Tensor x, Tensor filter, Tensor h) { + + // Convolution forward. 
+ CHECK_MIOPEN_ERROR(miopenConvolutionForward(miopen_handle_.handle(), + &alpha_, + x_desc_.desc(), + x.begin(), + w_desc_.desc(), + filter.begin(), + conv_desc_.desc(), + fwd_algo_, + &beta_, + h_desc_.desc(), + h.begin(), + fwd_workspace_.begin(), + fwd_workspace_size_ + )); + + } + + void backward_params(Tensor x, Tensor delta, Tensor dW) { + + CHECK_MIOPEN_ERROR(miopenConvolutionBackwardWeights(miopen_handle_.handle(), + &alpha_, + h_desc_.desc(), + delta.begin(), + x_desc_.desc(), + x.begin(), + conv_desc_.desc(), + bwd_params_algo_, + &beta_, + w_desc_.desc(), + dW.begin(), + bwd_params_workspace_.begin(), + bwd_params_workspace_size_ + )); + + + } + + void backward_inputs(Tensor filter, Tensor delta, Tensor dX) { + + CHECK_MIOPEN_ERROR(miopenConvolutionBackwardData(miopen_handle_.handle(), + &alpha_, + h_desc_.desc(), + delta.begin(), + w_desc_.desc(), + filter.begin(), + conv_desc_.desc(), + bwd_inputs_algo_, + &beta_, + x_desc_.desc(), + dX.begin(), + bwd_inputs_workspace_.begin(), + bwd_inputs_workspace_size_ + )); + + } +}; + +template +std::tuple time_cnn( + int k, int c, int r, int s, + int n, int h, int w, + int pad_h, int pad_w, + int hstride, int wstride, + int num_repeats + ) { + + + // Allocate memory for filter + auto filter = rand(std::vector{r, s, c, k}); + + // Allocate memory for input + auto input = rand(std::vector{w, h, c, n}); + miopenCNN cnn(w, h, c, n, k, r, s, pad_w, pad_h, wstride, hstride, input, filter); + + // Allocate memory for output tensor + auto output = cnn.getOutputTensor(); + + std::string fwd_algo_s = cnn.get_fwd_algo_string(); + + //Warm up + cnn.forward(input, filter, output); + + hipDeviceSynchronize(); + auto start = std::chrono::steady_clock::now(); + + for (int i = 0; i < num_repeats; ++i) { + cnn.forward(input, filter, output); + } + + hipDeviceSynchronize(); + auto end = std::chrono::steady_clock::now(); + int fwd_time = static_cast(std::chrono::duration(end - start).count() / num_repeats); + + // Allocate memory 
for backward pass wrt weights + auto delta = rand(cnn.get_output_dims()); + auto dW = zeros(std::vector{r, s, c, k}); + + // Warm up backward + cnn.backward_params(input, delta, dW); + + hipDeviceSynchronize(); + start = std::chrono::steady_clock::now(); + + for (int i = 0; i < num_repeats; ++i) { + // Backward pass wrt weights + cnn.backward_params(input, delta, dW); + } + + hipDeviceSynchronize(); + end = std::chrono::steady_clock::now(); + + int bwd_params_time = static_cast(std::chrono::duration(end - start).count() / num_repeats); + + //Allocate memory for backward pass wrt inputs + auto dX = zeros(std::vector{w, h, c, n}); + + //Warm up backward inputs + cnn.backward_inputs(filter, delta, dX); + + hipDeviceSynchronize(); + start = std::chrono::steady_clock::now(); + + for (int i = 0; i < num_repeats; ++i) { + // Backward pass wrt inputs + cnn.backward_inputs(filter, delta, dX); + + } + + hipDeviceSynchronize(); + end = std::chrono::steady_clock::now(); + + int bwd_inputs_time = static_cast(std::chrono::duration(end - start).count() / num_repeats); + + return std::tuple(fwd_time, bwd_inputs_time, bwd_params_time, fwd_algo_s); + +} + + +int main(int argc, char **argv) { + + int num_repeats = 100; + std::string precision ="float"; + + hipFree(0); + + if (argc > 1) + { + num_repeats = atoi(argv[1]); + precision = argv[2]; + } + + std::cout << std::setw(30) << "Times" << std::endl; + std::cout << std::setfill('-') << std::setw(190) << "-" << std::endl; + std::cout << std::setfill(' '); + std::cout << " w h c n k f_w f_h pad_w pad_h stride_w stride_h fwd_time (usec) bwd_inputs_time (usec) bwd_params_time (usec) total_time (usec) fwd_algo " << std::endl; + std::cout << std::setfill('-') << std::setw(190) << "-" << std::endl; + std::cout << std::setfill(' '); + + int total_fwd_time=0, total_bwd_inputs_time=0, total_bwd_params_time=0; + for (const auto &problem : training_set) { + + // Filter parameters + int k, c, r, s; // r - filter_h (f_h), s - filter_w (f_w) + + 
// Input parameters + int n, w, h; + + // Padding + int pad_w, pad_h; + + // Stride + int wstride, hstride; + + std::tie(w, h, c, n, k, s, r, pad_w, pad_h, wstride, hstride) = problem; + + int fwd_time, bwd_inputs_time, bwd_params_time; + std::string fwd_algo_s; + + if( precision == "float" ) + { + std::tie(fwd_time, bwd_inputs_time, bwd_params_time, fwd_algo_s) = + time_cnn(k, c, r, s, n, h, w, pad_h, pad_w, hstride, wstride, num_repeats); + } + else + { + throw std::runtime_error("unknown precision"); + } + + std::cout << std::setw(5) << w; + std::cout << std::setw(7) << h; + std::cout << std::setw(7) << c; + std::cout << std::setw(7) << n; + std::cout << std::setw(7) << k; + std::cout << std::setw(7) << s; + std::cout << std::setw(7) << r; + std::cout << std::setw(7) << pad_w; + std::cout << std::setw(8) << pad_h; + std::cout << std::setw(10) << wstride; + std::cout << std::setw(10) << hstride; + std::cout << std::setw(14) << std::setprecision(7) << fwd_time; + std::cout << std::setw(24) << std::setprecision(7) << bwd_inputs_time; + std::cout << std::setw(24) << std::setprecision(7) << bwd_params_time; + std::cout << std::setw(19) << std::setprecision(8) << fwd_time + bwd_inputs_time + bwd_params_time; + + std::cout << std::setw(25) << fwd_algo_s; + + std::cout << std::endl; + + total_fwd_time += fwd_time; + total_bwd_inputs_time += bwd_inputs_time; + total_bwd_params_time += bwd_params_time; + + } + + std::cout << std::setw(82) << "Totals" ; + std::cout << std::setw(14) << std::setprecision(7) << total_fwd_time; + std::cout << std::setw(24) << std::setprecision(7) << total_bwd_inputs_time; + std::cout << std::setw(24) << std::setprecision(7) << total_bwd_params_time; + std::cout << std::setw(19) << std::setprecision(8) << total_fwd_time + total_bwd_inputs_time + total_bwd_params_time; + std::cout << std::endl; + + return 0; + +} + + diff --git a/code/amd/gemm_bench.cpp b/code/amd/gemm_bench.cpp new file mode 100644 index 0000000..b1fb890 --- /dev/null +++ 
b/code/amd/gemm_bench.cpp @@ -0,0 +1,99 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "tensor.h" +#include "gemm_problems.h" + +int time_gemm(Tensor A, Tensor B, Tensor C, bool a_t, bool b_t, rocblas_handle handle) { + const float alpha = 1.f / static_cast(A.dims()[1]); + const float beta = 1.f; + + int m = C.dims()[0]; + int k = a_t ? A.dims()[0] : A.dims()[1]; + int n = C.dims()[1]; + + int numRepeats = std::max(std::ceil(1e11 / (m * k * n)), 10.); + + // Warm up + rocblas_status stat = rocblas_sgemm( + handle, + a_t ? rocblas_operation_transpose : rocblas_operation_none, + b_t ? rocblas_operation_transpose : rocblas_operation_none, + m, n, k, + &alpha, + A.begin(), A.dims()[0], + B.begin(), B.dims()[0], + &beta, + C.begin(), C.dims()[0] ); + + if (stat != rocblas_status_success) { + throw std::runtime_error("sgemm failed"); + } + + hipDeviceSynchronize(); + + auto start = std::chrono::steady_clock::now(); + + for (int i = 0; i < numRepeats; ++i) { + rocblas_status stat = rocblas_sgemm( + handle, + a_t ? rocblas_operation_transpose : rocblas_operation_none, + b_t ? 
rocblas_operation_transpose : rocblas_operation_none, + m, n, k, + &alpha, + A.begin(), A.dims()[0], + B.begin(), B.dims()[0], + &beta, + C.begin(), C.dims()[0] ); + if (stat != rocblas_status_success) { + throw std::runtime_error("sgemm failed"); + } + } + hipDeviceSynchronize(); + + auto end = std::chrono::steady_clock::now(); + + return static_cast(std::chrono::duration(end - start).count() / numRepeats); + +} + +int main(int argc, char **argv) { + hipFree(0); + hipSetDevice(1); + rocblas_handle handle; + rocblas_create_handle(&handle); + + + std::cout << std::setw(30) << "Times" << std::endl; + std::cout << std::setfill('-') << std::setw(88) << "-" << std::endl; + std::cout << std::setfill(' '); + std::cout << " m n k a_t b_t time (usec) " << std::endl; + for (const auto &problem : training_set) { + int m, n, k; + bool a_t, b_t; + std::tie(m, n, k, a_t, b_t) = problem; + + auto a = rand({a_t ? k : m, a_t ? m : k}); + auto b = rand({b_t ? n : k, b_t ? k : n}); + auto c = zeros({m, n}); + + std::cout << std::setw(7) << m; + std::cout << std::setw(7) << n; + std::cout << std::setw(7) << k; + std::cout << std::setw(7) << (a_t ? "t" : "n"); + std::cout << std::setw(7) << (b_t ? 
"t" : "n"); + std::cout << std::setw(13) << std::setprecision(6) << time_gemm(a, b, c, a_t, b_t, handle); + std::cout << std::endl; + } + + rocblas_destroy_handle(handle); + return 0; +} + diff --git a/code/amd/hip_helper.h b/code/amd/hip_helper.h new file mode 100644 index 0000000..ec3632d --- /dev/null +++ b/code/amd/hip_helper.h @@ -0,0 +1,16 @@ +#pragma once + +#include + +#include + +void throw_hip_error(hipError_t ret, int line, const char* filename) { + if (ret != hipSuccess) { + std::stringstream ss; + ss << "HIP failure: " << hipGetErrorString(ret) << + " in " << filename << " at line: " << line << std::endl; + throw std::runtime_error(ss.str()); + } +} + +#define CHECK_HIP_ERROR(ret) throw_hip_error(ret, __LINE__, __FILE__) diff --git a/code/amd/hip_stl.h b/code/amd/hip_stl.h new file mode 100644 index 0000000..d15bd84 --- /dev/null +++ b/code/amd/hip_stl.h @@ -0,0 +1,17 @@ +#pragma once + +namespace force { + +template +class device_ptr{ +public: + T *ptr; + device_ptr(T *ptr) : ptr(ptr) {} +}; + +template +void fill(device_ptr begin, device_ptr end, U val){ + +} + +} diff --git a/code/amd/miopen_helper.h b/code/amd/miopen_helper.h new file mode 100644 index 0000000..7cc495e --- /dev/null +++ b/code/amd/miopen_helper.h @@ -0,0 +1,264 @@ +#pragma once + +#include +#include +#include + +#include + +#include "hip_helper.h" + +void throw_miopen_err(miopenStatus_t status, int line, const char* filename) { + if (status != miopenStatusSuccess) { + std::stringstream ss; + ss << "MIOPEN failure: " << status << + " in " << filename << " at line: " << line << std::endl; + throw std::runtime_error(ss.str()); + } +} + +#define CHECK_MIOPEN_ERROR(status) throw_miopen_err(status, __LINE__, __FILE__) + +class MIOpenHandle { + hipStream_t stream_; + std::shared_ptr handle_; + + struct MIOpenHandleDeleter { + void operator()(miopenHandle_t * handle) { + miopenDestroy(*handle); + delete handle; + } + }; + +public: + MIOpenHandle() : handle_(new miopenHandle_t, 
MIOpenHandleDeleter()) { + CHECK_HIP_ERROR(hipStreamCreate(&stream_)); + CHECK_MIOPEN_ERROR(miopenCreateWithStream(handle_.get(), stream_)); + } + ~MIOpenHandle() { + CHECK_HIP_ERROR(hipStreamDestroy(stream_)); + } + + miopenHandle_t handle() const { return *handle_; }; +}; + +template +class TensorDescriptor { + std::shared_ptr desc_; + + struct TensorDescriptorDeleter { + void operator()(miopenTensorDescriptor_t * desc) { + miopenDestroyTensorDescriptor(*desc); + delete desc; + } + }; + +public: + TensorDescriptor() + { + miopenTensorDescriptor_t * desc = new miopenTensorDescriptor_t; + CHECK_MIOPEN_ERROR(miopenCreateTensorDescriptor(desc)); + + desc_.reset(desc, TensorDescriptorDeleter()); + } + + TensorDescriptor(std::vector lens, + std::vector strides) { + miopenDataType_t type; + if (std::is_same::value) + type = miopenFloat; + else + throw std::runtime_error("Unknown type"); + + miopenTensorDescriptor_t * desc = new miopenTensorDescriptor_t; + CHECK_MIOPEN_ERROR(miopenCreateTensorDescriptor(desc)); + CHECK_MIOPEN_ERROR(miopenSetTensorDescriptor(*desc, type, static_cast(lens.size()), &lens[0], &strides[0])); + + desc_.reset(desc, TensorDescriptorDeleter()); + } + + miopenTensorDescriptor_t desc() const { return *desc_; } + +}; + +template +class TensorDescriptorArray +{ + std::shared_ptr desc_array_; + + struct ArrayDeleter + { + int num_; + ArrayDeleter(int num) : num_(num) {} + + void operator()(miopenTensorDescriptor_t *desc_array) { + for (int i = 0; i < num_; ++i) { + miopenDestroyTensorDescriptor(desc_array[i]); + } + + delete[] desc_array; + } + }; + +public: + + TensorDescriptorArray(std::vector lens, + std::vector strides, + int num) + { + miopenDataType_t type; + if (std::is_same::value) + type = miopenFloat; + else + throw std::runtime_error("Unknown type"); + + miopenTensorDescriptor_t * desc_array = new miopenTensorDescriptor_t[num]; + + for (int i = 0; i < num; ++i) + { + CHECK_MIOPEN_ERROR(miopenCreateTensorDescriptor(&desc_array[i])); + 
CHECK_MIOPEN_ERROR(miopenSetTensorDescriptor(desc_array[i], type, lens.size(), + &lens[0], &strides[0]) ); + } + + desc_array_.reset(desc_array, ArrayDeleter(num)); + } + + miopenTensorDescriptor_t * ptr() const { return desc_array_.get(); } +}; + +template +class TensorDescriptor4d { + std::shared_ptr desc_; + + struct TensorDescriptor4dDeleter { + void operator()(miopenTensorDescriptor_t * desc) { + miopenDestroyTensorDescriptor(*desc); + delete desc; + } + }; + +public: + + TensorDescriptor4d() {} + TensorDescriptor4d(const int n, const int c, const int h, const int w) { + miopenDataType_t type; + if (std::is_same::value) + type = miopenFloat; + else + throw std::runtime_error("Unknown type"); + + miopenTensorDescriptor_t * desc = new miopenTensorDescriptor_t; + CHECK_MIOPEN_ERROR(miopenCreateTensorDescriptor(desc)); + CHECK_MIOPEN_ERROR(miopenSet4dTensorDescriptor(*desc, + type, + n, + c, + h, + w)); + + desc_.reset(desc, TensorDescriptor4dDeleter()); + } + + miopenTensorDescriptor_t desc() const { return *desc_; } + +}; + +template +class FilterDescriptor4d { + std::shared_ptr desc_; + + struct FilterDescriptor4dDeleter { + void operator()(miopenTensorDescriptor_t * desc) { + miopenDestroyTensorDescriptor(*desc); + delete desc; + } + }; + +public: + FilterDescriptor4d(int k, int c, int h, int w) { + miopenDataType_t type; + if (std::is_same::value) + type = miopenFloat; + else + throw std::runtime_error("Unknown type"); + + miopenTensorDescriptor_t * desc = new miopenTensorDescriptor_t; + CHECK_MIOPEN_ERROR(miopenCreateTensorDescriptor(desc)); + CHECK_MIOPEN_ERROR(miopenSet4dTensorDescriptor(*desc, type, k, c, h, w)); + + desc_.reset(desc, FilterDescriptor4dDeleter()); + } + + miopenTensorDescriptor_t desc() const { return *desc_; } + +}; + +class ConvolutionDescriptor { + std::shared_ptr desc_; + + struct ConvolutionDescriptorDeleter { + void operator()(miopenConvolutionDescriptor_t * desc) { + miopenDestroyConvolutionDescriptor(*desc); + delete desc; + } + 
}; +public: + + + ConvolutionDescriptor(int pad_h, int pad_w, int hstride, int wstride) : + desc_(new miopenConvolutionDescriptor_t, ConvolutionDescriptorDeleter()) { + + CHECK_MIOPEN_ERROR(miopenCreateConvolutionDescriptor(desc_.get())); + CHECK_MIOPEN_ERROR(miopenInitConvolutionDescriptor(*desc_, + miopenConvolution, + pad_h, + pad_w, + hstride, + wstride, + 1, + 1)); + } + + miopenConvolutionDescriptor_t desc() const { return *desc_; }; + +}; + +class RNNDescriptor { + std::shared_ptr desc_; + + struct RNNDescriptorDeleter { + void operator()(miopenRNNDescriptor_t * desc) { + miopenDestroyRNNDescriptor(*desc); + delete desc; + } + }; +public: + + RNNDescriptor() {} + + RNNDescriptor(const int hsize, + const int nlayers, + miopenRNNInputMode_t inMode, + miopenRNNDirectionMode_t direction, + miopenRNNMode_t rnnMode, + miopenRNNBiasMode_t biasMode, + miopenRNNAlgo_t algo, + miopenDataType_t dataType) : + desc_(new miopenRNNDescriptor_t, RNNDescriptorDeleter()) + { + CHECK_MIOPEN_ERROR(miopenCreateRNNDescriptor(desc_.get())); + CHECK_MIOPEN_ERROR(miopenSetRNNDescriptor(*desc_, + hsize, + nlayers, + inMode, + direction, + rnnMode, + biasMode, + algo, + dataType)); + } + + miopenRNNDescriptor_t desc() const { return *desc_; }; +}; + diff --git a/code/amd/rnn_bench_rocm.cpp b/code/amd/rnn_bench_rocm.cpp new file mode 100644 index 0000000..30e7b80 --- /dev/null +++ b/code/amd/rnn_bench_rocm.cpp @@ -0,0 +1,300 @@ +#include +#include +#include +#include +#include + +#include +#include + +#include "tensor.h" +#include "miopen_helper.h" +#include "rnn_problems.h" + +template +class miopenRNN +{ + MIOpenHandle miopen_handle_; + RNNDescriptor rnnDesc_; + + int sequenceLen_; + + TensorDescriptorArray xDescArray_; + TensorDescriptorArray yDescArray_; + TensorDescriptor hxDesc_; + TensorDescriptor hyDesc_; + TensorDescriptor cxDesc_; + TensorDescriptor cyDesc_; + TensorDescriptor wDesc_; + + TensorDescriptorArray dxDescArray_; + TensorDescriptorArray dyDescArray_; + 
TensorDescriptor dhxDesc_; + TensorDescriptor dhyDesc_; + TensorDescriptor dcxDesc_; + TensorDescriptor dcyDesc_; + + size_t weight_size_byte_; + size_t workspace_size_byte_; + size_t trainspace_size_byte_; + + Tensor weights_; + Tensor workspace_; + Tensor trainspace_; + +public: + miopenRNN(int hidden_size, int batch_size, int time_steps, const std::string& rnn_type) : + sequenceLen_(time_steps), + xDescArray_ ({batch_size, hidden_size}, {hidden_size, 1}, time_steps), + yDescArray_ ({batch_size, hidden_size}, {hidden_size, 1}, time_steps), + dxDescArray_({batch_size, hidden_size}, {hidden_size, 1}, time_steps), + dyDescArray_({batch_size, hidden_size}, {hidden_size, 1}, time_steps), + hxDesc_ ({1, batch_size, hidden_size}, {hidden_size * batch_size, hidden_size, 1}), + hyDesc_ ({1, batch_size, hidden_size}, {hidden_size * batch_size, hidden_size, 1}), + dhxDesc_({1, batch_size, hidden_size}, {hidden_size * batch_size, hidden_size, 1}), + dhyDesc_({1, batch_size, hidden_size}, {hidden_size * batch_size, hidden_size, 1}), + cxDesc_ ({1, batch_size, hidden_size}, {hidden_size * batch_size, hidden_size, 1}), + cyDesc_ ({1, batch_size, hidden_size}, {hidden_size * batch_size, hidden_size, 1}), + dcxDesc_({1, batch_size, hidden_size}, {hidden_size * batch_size, hidden_size, 1}), + dcyDesc_({1, batch_size, hidden_size}, {hidden_size * batch_size, hidden_size, 1}) + { + miopenRNNMode_t rnn_mode; + if( rnn_type == "vanilla") + rnn_mode = miopenRNNRELU; + else if( rnn_type == "gru") + rnn_mode = miopenGRU; + else if( rnn_type == "lstm") + rnn_mode = miopenLSTM; + else + throw std::runtime_error("Unknow rnn mode in miopenRNN"); + + miopenDataType_t data_type; + if(std::is_same::value) + data_type = miopenFloat; + else + throw std::runtime_error("Unknow data type in miopenRNN"); + + rnnDesc_ = RNNDescriptor(hidden_size, + 1, + miopenRNNskip, + miopenRNNunidirection, + rnn_mode, + miopenRNNNoBias, + miopenRNNdefault, + data_type); + + 
CHECK_MIOPEN_ERROR(miopenGetRNNParamsDescriptor( miopen_handle_.handle(), + rnnDesc_.desc(), + xDescArray_.ptr()[0], + wDesc_.desc(), + data_type)); + + CHECK_MIOPEN_ERROR( miopenGetRNNParamsSize( miopen_handle_.handle(), + rnnDesc_.desc(), + xDescArray_.ptr()[0], + &weight_size_byte_, + data_type) ); + + + CHECK_MIOPEN_ERROR( miopenGetRNNWorkspaceSize( miopen_handle_.handle(), + rnnDesc_.desc(), + sequenceLen_, + xDescArray_.ptr(), + &workspace_size_byte_) ); + + CHECK_MIOPEN_ERROR( miopenGetRNNTrainingReserveSize(miopen_handle_.handle(), + rnnDesc_.desc(), + sequenceLen_, + xDescArray_.ptr(), + &trainspace_size_byte_) ); + + weights_ = rand(std::vector{static_cast(weight_size_byte_/sizeof(T))}); + workspace_ = zeros(std::vector{static_cast(workspace_size_byte_/sizeof(float))}); + trainspace_ = zeros(std::vector{static_cast(trainspace_size_byte_/sizeof(float))}); + + } + + void forward(Tensor x, Tensor hx, Tensor cx, + Tensor y, Tensor hy, Tensor cy) + { + CHECK_MIOPEN_ERROR(miopenRNNForwardTraining(miopen_handle_.handle(), + rnnDesc_.desc(), + sequenceLen_, + xDescArray_.ptr(), + (void *)x.begin(), + hxDesc_.desc(), + (void *)hx.begin(), + cxDesc_.desc(), + (void *)cx.begin(), + wDesc_.desc(), + (void *)weights_.begin(), + yDescArray_.ptr(), + (void *)y.begin(), + hyDesc_.desc(), + (void *)hy.begin(), + cyDesc_.desc(), + (void *)cy.begin(), + (void *)workspace_.begin(), + workspace_size_byte_, + (void *)trainspace_.begin(), + trainspace_size_byte_) ); + } + + void backward_data( Tensor y, Tensor dy, Tensor dhy, + Tensor dcy, Tensor hx, Tensor cx, + Tensor dx, Tensor dhx, Tensor dcx) { + CHECK_MIOPEN_ERROR(miopenRNNBackwardData(miopen_handle_.handle(), + rnnDesc_.desc(), + sequenceLen_, + yDescArray_.ptr(), + (void *)y.begin(), + dyDescArray_.ptr(), + (void *)dy.begin(), + dhyDesc_.desc(), + (void *)dhy.begin(), + dcyDesc_.desc(), + (void *)dcy.begin(), + wDesc_.desc(), + (void *)weights_.begin(), + hxDesc_.desc(), + (void *)hx.begin(), + cxDesc_.desc(), + (void 
*)cx.begin(), + dxDescArray_.ptr(), + (void *)dx.begin(), + dhxDesc_.desc(), + (void *)dhx.begin(), + dcxDesc_.desc(), + (void *)dcx.begin(), + (void *)workspace_.begin(), + workspace_size_byte_, + (void *)trainspace_.begin(), + trainspace_size_byte_) ); + } + +}; + +template +std::tuple time_rnn( + int hidden_size, int batch_size, int time_steps, const std::string & type, int inference) +{ + int num_repeats = 100; + + auto x = rand({time_steps, batch_size, hidden_size}); + auto y = rand({time_steps, batch_size, hidden_size}); + auto dx = rand({time_steps, batch_size, hidden_size}); + auto dy = rand({time_steps, batch_size, hidden_size}); + + auto hx = rand({1, batch_size, hidden_size}); + auto hy = rand({1, batch_size, hidden_size}); + auto cx = rand({1, batch_size, hidden_size}); + auto cy = rand({1, batch_size, hidden_size}); + auto dhx = rand({1, batch_size, hidden_size}); + auto dhy = rand({1, batch_size, hidden_size}); + auto dcx = rand({1, batch_size, hidden_size}); + auto dcy = rand({1, batch_size, hidden_size}); + + miopenRNN rnn(hidden_size, batch_size, time_steps, type); + + //Warm up + rnn.forward(x, hx, cx, y, hy, cy); + + hipDeviceSynchronize(); + auto start = std::chrono::steady_clock::now(); + + for (int i = 0; i < num_repeats; ++i) { + rnn.forward(x, hx, cx, y, hy, cy); + } + + hipDeviceSynchronize(); + + auto end = std::chrono::steady_clock::now(); + int fwd_time = static_cast(std::chrono::duration(end - start).count() / num_repeats); + + int bwd_inputs_time = 0; + int bwd_params_time = 0; + + if (!inference) + { + //Warm up + rnn.backward_data(y, dy, dhy, dcy, + hx, cx, dx, dhx, dcx); + + hipDeviceSynchronize(); + + start = std::chrono::steady_clock::now(); + + for (int i = 0; i < num_repeats; ++i) + { + rnn.backward_data(y, dy, dhy, dcy, + hx, cx, dx, dhx, dcx); + } + hipDeviceSynchronize(); + + end = std::chrono::steady_clock::now(); + bwd_inputs_time = std::chrono::duration(end - start).count() / num_repeats; + } + + return 
std::tuple(fwd_time, bwd_inputs_time, bwd_params_time); +} + + +int main(int argc, char **argv) { + + hipFree(0); + + int inference = 0; + + if (argc > 1) { + std::string inf = "inference"; + inference = argv[1] == inf ? 1 : 0; + } + + std::cout << std::setw(30) << "Times" << std::endl; + std::cout << std::setfill('-') << std::setw(190) << "-" << std::endl; + std::cout << std::setfill(' '); + std::cout << "hidden_size batch_size time_steps rnn_type fwd_time (usec) bwd_inputs_time (usec) bwd_params_time (usec) total_time (usec) " << std::endl; + std::cout << std::setfill('-') << std::setw(190) << "-" << std::endl; + std::cout << std::setfill(' '); + + int total_fwd_time=0, total_bwd_inputs_time=0, total_bwd_params_time=0; + for (const auto &problem : training_set) { + + int hidden_size, batch_size, time_steps; + std::string rnn_type; + std::tie(hidden_size, batch_size, time_steps, rnn_type) = problem; + + int fwd_time, bwd_inputs_time, bwd_params_time; + + std::tie(fwd_time, bwd_inputs_time, bwd_params_time) = + time_rnn(hidden_size, batch_size, time_steps, rnn_type, inference); + + std::cout << std::setw(5) << hidden_size; + std::cout << std::setw(15) << batch_size; + std::cout << std::setw(15) << time_steps; + std::cout << std::setw(19) << rnn_type; + std::cout << std::setw(11) << std::setprecision(7) << fwd_time; + std::cout << std::setw(24) << std::setprecision(7) << bwd_inputs_time; + std::cout << std::setw(24) << std::setprecision(7) << bwd_params_time; + std::cout << std::setw(19) << std::setprecision(8) << fwd_time + bwd_inputs_time + bwd_params_time; + + std::cout << std::endl; + + total_fwd_time += fwd_time; + total_bwd_inputs_time += bwd_inputs_time; + total_bwd_params_time += bwd_params_time; + + } + + std::cout << std::setw(82) << "Totals" ; + std::cout << std::setw(14) << std::setprecision(7) << total_fwd_time; + std::cout << std::setw(24) << std::setprecision(7) << total_bwd_inputs_time; + std::cout << std::setw(24) << std::setprecision(7) << 
total_bwd_params_time; + std::cout << std::setw(19) << std::setprecision(8) << total_fwd_time + total_bwd_inputs_time + total_bwd_params_time; + std::cout << std::endl; + + return 0; + +} + + diff --git a/code/amd/tensor.h b/code/amd/tensor.h new file mode 100644 index 0000000..d9f8ad7 --- /dev/null +++ b/code/amd/tensor.h @@ -0,0 +1,73 @@ +#pragma once + +#include +#include +#include +#include + +#include + +template +class Tensor { + std::vector dims_; + int size_; + + struct deleteDevPtr { + void operator()(T *p) const { + hipFree(p); + } + }; + + +public: + std::shared_ptr ptr_; + + Tensor() {} + + Tensor(std::vector dims) : dims_(dims) { + T* tmp_ptr; + size_ = std::accumulate(dims_.begin(), dims_.end(), 1, std::multiplies()); + hipMalloc(&tmp_ptr, sizeof(T) * size_); + + ptr_.reset(tmp_ptr, deleteDevPtr()); + } + + T* begin() const { return ptr_.get(); } + T* end() const { return ptr_.get() + size_; } + int size() const { return size_; } + std::vector dims() const { return dims_; } +}; + +template +Tensor fill(std::vector dims, T val) { + Tensor tensor(dims); + size_t d = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); + std::vector host_ptr(d); + std::fill(host_ptr.begin(), host_ptr.end(), val); + hipMemcpy(tensor.ptr_.get(), host_ptr.data(), d*sizeof(T), hipMemcpyHostToDevice); + return tensor; +} + +template +Tensor zeros(std::vector dims) +{ + Tensor tensor(dims); + size_t d = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); + hipMemset(tensor.ptr_.get(), 0, d*sizeof(T)); + return tensor; +} + +template +Tensor rand(std::vector dims) +{ + Tensor tensor(dims); + size_t d = std::accumulate(dims.begin(), dims.end(), 1, std::multiplies()); + std::vector host_ptr(d); + std::srand(std::time(0)); + for(int i=0;i