diff --git a/python/jittor/compile_extern.py b/python/jittor/compile_extern.py index 55944a72..4293dd9d 100644 --- a/python/jittor/compile_extern.py +++ b/python/jittor/compile_extern.py @@ -560,6 +560,59 @@ def inner(self, *args, **kw): if k == "mpi_test": continue setattr(core.Var, k, wrapper(mpi_ops.__dict__[k])) +def install_cutlass(root_folder): + url = "https://cloud.tsinghua.edu.cn/f/8fc42499904f43e39141/?dl=1" + + filename = "cutlass.zip" + fullname = os.path.join(root_folder, filename) + dirname = os.path.join(root_folder, "cutlass") + true_md5 = "41bb524a6bad4612d6017ed4b11f1d28" + + if os.path.exists(fullname): + md5 = run_cmd('md5sum '+fullname).split()[0] + if md5 != true_md5: + os.remove(fullname) + if os.path.isdir(dirname): + shutil.rmtree(dirname) + if not os.path.isdir(os.path.join(dirname, "include")): + if not os.path.isfile(os.path.join(root_folder, filename)): + LOG.i("Downloading cutlass...") + download_url_to_local(url, filename, root_folder, true_md5) + + if core.get_device_count() == 0: + return + shutil.unpack_archive(fullname, root_folder) + return dirname + +def setup_cutlass(): + use_cutlass = os.environ.get("use_cutlass", "0")=="1" + if not has_cuda: + use_cutlass = False + return + if not use_cutlass: return + cutlass_include_path = os.environ.get("cutlass_include_path") + print(cutlass_include_path) + if cutlass_include_path is None: + LOG.v("setup cutlass...") + from pathlib import Path + cutlass_path = os.path.join(str(Path.home()), ".cache", "jittor", "cutlass") + + make_cache_dir(cutlass_path) + cutlass_home = install_cutlass(cutlass_path) + if cutlass_home is None: return + os.environ['cutlass_include_path'] = cutlass_home + cutlass_include_path = os.path.join(cutlass_home, "include") + cutlass_tool_include_path = os.path.join(cutlass_home, "tools", "util", "include") + all_dir = f" -I\"{cutlass_include_path}\" -I\"{cutlass_tool_include_path}\"" + cutlass_src_dir = os.path.join(jittor_path, "extern", "cuda", "cutlass") + cutlass_src_files = [] + for r, _, f in os.walk(cutlass_src_dir): + for fname in f: + cutlass_src_files.append(os.path.join(r, fname)) + cutlass_ops = compile_custom_ops(cutlass_src_files, + extra_flags=f" {all_dir} ") + LOG.vv("Get cutlass_ops: "+str(dir(cutlass_ops))) + in_mpi = inside_mpi() FIX_TORCH_ERROR = 0 if os.name != 'nt' and not in_mpi: @@ -581,6 +634,7 @@ def inner(self, *args, **kw): setup_nccl() setup_cutt() +setup_cutlass() try: setup_mkl() diff --git a/python/jittor/cutlass_ops.py b/python/jittor/cutlass_ops.py new file mode 100644 index 00000000..abcf0048 --- /dev/null +++ b/python/jittor/cutlass_ops.py @@ -0,0 +1,1529 @@ +import jittor as jt +import numpy as np +from jittor import nn +import time +import sys +import os + +cutlass_path = os.environ.get('cutlass_include_path') +def depthwise_src_backward(x, weights): + cuda_header = ''' + #undef out + #include + + #include + #include + + #include + + #include + #include + + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include "executor.h" + + #include + + + // The code section below describes datatype for input, output tensors and + // computation between elements + using ElementAccumulator = float; // Data type of accumulator + using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) + using ElementSrc = float; // Data type of elements in src tensor + using ElementFilter = float; // Data type of elements in filter tensor + using ElementDst = float; // Data type of elements in 
output tensor + + using LayoutSrc = cutlass::layout::TensorNCHW; + using LayoutFilter = cutlass::layout::TensorNCHW; + using LayoutDst = cutlass::layout::TensorNCHW; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm75; + + // This code section describes the tile size a thread block will compute + using ThreadblockShape = + cutlass::gemm::GemmShape<32, 32, 8>; // Threadblock tile shape + + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; // Warp tile shape + + // This code section describes the size of MMA op + using InstructionShape = + cutlass::gemm::GemmShape<1, 1, 1>; // TensorCore instruction shape + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseConvolutionDgradThreadblockSwizzle; + + // Number of pipelines you want to use + constexpr int NumStages = 2; + + // This code section describes the epilogue part of the kernel, we use default + // value + using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombination< + ElementDst, // Data type of output matrix. + 1, ElementAccumulator, // Data type of accumulator + ElementDst, // Data type of bias + ElementComputeEpilogue>; // Data type for alpha/beta in linear + // combination + using Convolution = cutlass::conv::device::Deconvolution< + ElementSrc, LayoutSrc, ElementFilter, LayoutFilter, ElementDst, + LayoutDst, ElementDst, LayoutDst, ElementDst, + cutlass::conv::ConvType::kDepthwiseConvolution, MMAOp, SmArch, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOp, + SwizzleThreadBlock, NumStages, 1, 1, + cutlass::conv::SpecialOptimizeDesc::NONE, cutlass::arch::OpMultiplyAdd, + cutlass::conv::ImplicitGemmMode::GEMM_TN>; + + struct Options { + bool help; + cutlass::Tensor4DCoord input_size; + cutlass::Tensor4DCoord filter_size; + cutlass::Tensor4DCoord padding; + cutlass::MatrixCoord conv_stride; + cutlass::MatrixCoord dilation; + bool reference_check; + bool measure_performance; + int iterations; + bool save_workspace; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + bool benchmark; + std::string tag; + + Options() + : help(false), + input_size(1, 32, 32, 32), + filter_size(32, 3, 3, 1), + padding(1, 1, 1, 1), + conv_stride(1, 1), + dilation(1, 1), + reference_check(false), + measure_performance(true), + iterations(1000), + save_workspace(false), + alpha(1), + beta(0), + benchmark(false) {} + + // Verify the problem size is compatible with the CUTLASS Convolution + // implementation. 
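+        // valid() below rejects tensors that break the kAlignment requirement and
+        // any padding that is not filter_size/2, i.e. the "same"-style padding
+        // that update() always produces for this depthwise dgrad kernel.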
+ bool valid() { + int const kAlignment = 1; + + if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) { + // misaligned tensors + return false; + } + + // Invalid padding + if ((padding.h() != filter_size.h() / 2) || + (padding.w() != filter_size.w() / 2)) { + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update(cutlass::Tensor4DCoord input_size, + cutlass::Tensor4DCoord filter_size) { + this->input_size = input_size; + this->filter_size = filter_size; + + padding.n() = filter_size.h() / 2; + padding.h() = filter_size.h() / 2; + padding.w() = filter_size.w() / 2; + padding.c() = filter_size.w() / 2; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor4DCoord output_size() const { + return cutlass::Tensor4DCoord( + input_size.n(), + (input_size.h() + padding.n() + padding.h() - filter_size.h()) / + conv_stride.row() + + 1, + (input_size.w() + padding.w() + padding.c() - filter_size.w()) / + conv_stride.column() + + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Number of multiply-adds = NPQK * CRS / K + int64_t fmas = + output_size().product() * + int64_t(filter_size.h() * filter_size.w() * filter_size.c()) / + output_size().c(); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } + + }; + + #define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " \ + << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + #define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") + #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + + ''' + if x.dtype == jt.float16: + cuda_header = cuda_header.replace("float", "cutlass::half_t") + cuda_src = ''' + @alias(weights, in1) + @alias(x, in0) + @alias(dst, out0) + bool notSupported = false; + Options options = Options(); + options.update( + {x_shape0, x_shape2, x_shape3, x_shape1}, + {weights_shape0, weights_shape2, weights_shape3, 1}); + + cutlass::TensorRef d_src((ElementSrc*)x_p, + LayoutSrc().packed({options.input_size})); + cutlass::TensorRef d_filter((ElementFilter*)weights_p, + LayoutFilter().packed(options.filter_size)); + cutlass::TensorRef d_dst((ElementDst*)dst_p, LayoutDst().packed(options.output_size())); + cutlass::TensorRef d_bias = {nullptr, Convolution::LayoutDst()}; + cutlass::TensorRef d_z = {nullptr, Convolution::LayoutDst()}; + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + int split_k_slices = 1; + typename Convolution::Arguments arguments{ + {options.input_size, options.filter_size, options.padding, + options.conv_stride, options.dilation, options.output_size(), mode, + split_k_slices, options.filter_size.n()}, + d_src, // tensor_src.device_ref(), + d_filter, // tensor_filter.device_ref(), + d_bias, // tensor_bias.device_ref(), + d_z, // tensor_z.device_ref(), + d_dst, // tensor_dst.device_ref(), + {options.alpha, 0, options.beta}}; + + Convolution conv_op; + + size_t workspace_size = 
conv_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(conv_op.initialize(arguments, workspace.get())); + // auto temp1 = exe.alloc_temp(workspace_size); + // CUTLASS_CHECK(conv_op.initialize(arguments, temp1.ptr)); + { + //static SimpleProfiler _("aa"); + //SimpleProfilerGuard __(_); + + //ccccc c + CUTLASS_CHECK(conv_op()); + // kernel<<<1,1>>>(); + } + ''' + output = jt.zeros(x.shape) + output = jt.code([x, weights], [output], cuda_header=cuda_header, cuda_src=cuda_src)[0] + output.compile_options = {f"FLAGS: --expt-relaxed-constexpr -I{cutlass_path}/include -I{cutlass_path}/tools/util/include ": 1} + return output + +def depthwise_filter_backward(x, weights, diff): + cuda_header = ''' + #undef out + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + #include "executor.h" + + #include + + // The code section below describes datatype for input, output tensors and + // computation between elements + using ElementAccumulator = float; // Data type of accumulator + using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) + using ElementSrc = float; // Data type of elements in src tensor + using ElementFilter = float; // Data type of elements in filter tensor + using ElementDst = float; // Data type of elements in output tensor + + using LayoutSrc = cutlass::layout::TensorNCHW; + using LayoutFilter = cutlass::layout::TensorNCHW; + using LayoutDst = cutlass::layout::TensorNCHW; + using LayoutGrad = cutlass::layout::TensorNCHW; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm75; + + // This code section describes the tile size a thread block will compute + using ThreadblockShape = + cutlass::gemm::GemmShape<32, 32, 8>; // Threadblock tile shape + + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; // Warp tile shape + + // This code section describes the size of MMA op + using InstructionShape = + cutlass::gemm::GemmShape<1, 1, 1>; // TensorCore instruction shape + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseConvolutionWgradThreadblockSwizzle; + + // Number of pipelines you want to use + constexpr int NumStages = 2; + + // This code section describes the epilogue part of the kernel, we use default + // value + using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementFilter, // Data type of output matrix. 
+ 1, ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue>; // Data type for alpha/beta in linear + // combination + using Convolution = cutlass::conv::device::ConvolutionBackwardFilter< + ElementSrc, LayoutSrc, ElementDst, LayoutDst, ElementFilter, + LayoutFilter, ElementFilter, + cutlass::conv::ConvType::kDepthwiseConvolution, MMAOp, SmArch, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOp, + SwizzleThreadBlock, NumStages, 1, 1, + cutlass::conv::SpecialOptimizeDesc::NONE, cutlass::arch::OpMultiplyAdd, + cutlass::conv::ImplicitGemmMode::GEMM_NT>; + + struct Options { + bool help; + cutlass::Tensor4DCoord input_size; + cutlass::Tensor4DCoord filter_size; + cutlass::Tensor4DCoord padding; + cutlass::MatrixCoord conv_stride; + cutlass::MatrixCoord dilation; + bool reference_check; + bool measure_performance; + int iterations; + bool save_workspace; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + bool benchmark; + std::string tag; + + Options() + : help(false), + input_size(1, 32, 32, 32), + filter_size(32, 3, 3, 1), + padding(1, 1, 1, 1), + conv_stride(1, 1), + dilation(1, 1), + reference_check(false), + measure_performance(true), + iterations(1000), + save_workspace(false), + alpha(1), + beta(0), + benchmark(false) {} + + // Verify the problem size is compatible with the CUTLASS Convolution + // implementation. + bool valid() { + int const kAlignment = 1; + + if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) { + // misaligned tensors + return false; + } + + // Invalid padding + if ((padding.h() != filter_size.h() / 2) || + (padding.w() != filter_size.w() / 2)) { + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update(cutlass::Tensor4DCoord input_size, + cutlass::Tensor4DCoord filter_size) { + this->input_size = input_size; + this->filter_size = filter_size; + + padding.n() = filter_size.h() / 2; + padding.h() = filter_size.h() / 2; + padding.w() = filter_size.w() / 2; + padding.c() = filter_size.w() / 2; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor4DCoord output_size() const { + return cutlass::Tensor4DCoord( + input_size.n(), + (input_size.h() + padding.n() + padding.h() - filter_size.h()) / + conv_stride.row() + + 1, + (input_size.w() + padding.w() + padding.c() - filter_size.w()) / + conv_stride.column() + + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Number of multiply-adds = NPQK * CRS / K + int64_t fmas = + output_size().product() * + int64_t(filter_size.h() * filter_size.w() * filter_size.c()) / + output_size().c(); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } + + }; + + #define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " \ + << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + #define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + #define CHECK_CUDA(x) TORCH_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor") + #define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + #define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x) + ''' + if x.dtype == 
jt.float16: + cuda_header = cuda_header.replace("float", "cutlass::half_t") + cuda_src = ''' + @alias(grad, in2) + @alias(weights, in1) + @alias(x, in0) + @alias(dst, out0) + bool notSupported = false; + Options options = Options(); + options.update( + {x_shape0, x_shape2, x_shape3, x_shape1}, + {weights_shape0, weights_shape2, weights_shape3, 1}); + + cutlass::TensorRef d_src((ElementSrc*)x_p, + LayoutSrc().packed({options.input_size})); + cutlass::TensorRef d_diff((ElementFilter*)grad_p, + LayoutDst().packed(options.output_size())); + cutlass::TensorRef d_filter((ElementFilter*)dst_p, LayoutFilter().packed(options.filter_size)); + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + int split_k_slices = 1; + typename Convolution::Arguments arguments{ + {options.input_size, options.filter_size, options.padding, + options.conv_stride, options.dilation, options.output_size(), mode, + split_k_slices, options.filter_size.n()}, + d_src, // tensor_src.device_ref(), + d_diff, // tensor_filter.device_ref(), + d_filter, + {options.alpha}}; + + Convolution conv_op; + + size_t workspace_size = conv_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(conv_op.initialize(arguments, workspace.get())); + // auto temp1 = exe.alloc_temp(workspace_size); + // CUTLASS_CHECK(conv_op.initialize(arguments, temp1.ptr)); + { + //static SimpleProfiler _("aa"); + //SimpleProfilerGuard __(_); + + //ccccc c + CUTLASS_CHECK(conv_op()); + // kernel<<<1,1>>>(); + } + ''' + output = jt.zeros(weights.shape) + output = jt.code([x, weights, diff], [output], cuda_header=cuda_header, cuda_src=cuda_src)[0] + output.compile_options = {f"FLAGS: --expt-relaxed-constexpr -I{cutlass_path}/include -I{cutlass_path}/tools/util/include ": 1} + return output + +class DepthwiseConv(jt.Function): + def __init__(self, stride=1, padding=0, dilation=1): + self.stride = stride if isinstance(stride, tuple) else (stride, stride) + self.padding = padding if isinstance(padding, tuple) else (padding, padding) + self.dilation = dilation if isinstance(dilation, tuple) else (dilation, dilation) + self.x = None + self.weights = None + + def execute(self, x, weights): + self.x = x + self.weights = weights + cuda_header = ''' + #undef out + #include + #include + #include + #include + + #include + #include + + #include + + #include + #include + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include "executor.h" + + // The code section below describes datatype for input, output tensors and + // computation between elements + using ElementAccumulator = float; // Data type of accumulator + using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) + using ElementSrc = float; // Data type of elements in src tensor + using ElementFilter = float; // Data type of elements in filter tensor + using ElementDst = float; // Data type of elements in output tensor + + using LayoutSrc = cutlass::layout::TensorNCHW; + using LayoutFilter = cutlass::layout::TensorNCHW; + using LayoutDst = cutlass::layout::TensorNCHW; + + // This code section describes whether you want to use tensor cores or regular + // SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassSimt; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm75; + + // This code section describes the tile size a thread block will compute + using ThreadblockShape = + 
cutlass::gemm::GemmShape<32, 32, 8>; // Threadblock tile shape + + // This code section describes tile size a warp will compute + using WarpShape = cutlass::gemm::GemmShape<8, 16, 8>; // Warp tile shape + + // This code section describes the size of MMA op + using InstructionShape = + cutlass::gemm::GemmShape<1, 1, 1>; // TensorCore instruction shape + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseConvolutionFpropThreadblockSwizzle; + + // Number of pipelines you want to use + constexpr int NumStages = 1; + + // This code section describes the epilogue part of the kernel, we use default + // value + using EpilogueOp = cutlass::epilogue::thread::BiasAddLinearCombination< + ElementDst, // Data type of output matrix. + 1, ElementAccumulator, // Data type of accumulator + ElementDst, // Data type of bias + ElementComputeEpilogue>; // Data type for alpha/beta in linear + // combination + using Convolution = cutlass::conv::device::Convolution< + ElementSrc, LayoutSrc, ElementFilter, LayoutFilter, ElementDst, + LayoutDst, ElementDst, LayoutDst, ElementDst, + cutlass::conv::ConvType::kDepthwiseConvolution, MMAOp, SmArch, + ThreadblockShape, WarpShape, InstructionShape, EpilogueOp, + SwizzleThreadBlock, NumStages, 1, 1, + cutlass::conv::SpecialOptimizeDesc::NONE, cutlass::arch::OpMultiplyAdd, + cutlass::conv::ImplicitGemmMode::GEMM_TN>; + + struct Options { + bool help; + cutlass::Tensor4DCoord input_size; + cutlass::Tensor4DCoord filter_size; + cutlass::Tensor4DCoord padding; + cutlass::MatrixCoord conv_stride; + cutlass::MatrixCoord dilation; + bool reference_check; + bool measure_performance; + int iterations; + bool save_workspace; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + bool benchmark; + std::string tag; + + Options() + : help(false), + input_size(1, 32, 32, 32), + filter_size(32, 3, 3, 1), + padding(1, 1, 1, 1), + conv_stride(1, 1), + dilation(1, 1), + reference_check(false), + measure_performance(false), + iterations(1000), + save_workspace(false), + alpha(1), + beta(0), + benchmark(false) {} + + // Verify the problem size is compatible with the CUTLASS Convolution + // implementation. 
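+        // Coordinate convention: the host code further below fills these
+        // Tensor4DCoord values in (N, H, W, C) order from Jittor's NCHW shapes,
+        // i.e. update() is called with {x_shape0, x_shape2, x_shape3, x_shape1}
+        // for the input and {weights_shape0, weights_shape2, weights_shape3, 1}
+        // for the depthwise filter.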
+ bool valid() { + int const kAlignment = 1; + + if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) { + // misaligned tensors + return false; + } + + // Invalid padding + if ((padding.h() != filter_size.h() / 2) || + (padding.w() != filter_size.w() / 2)) { + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update(cutlass::Tensor4DCoord input_size, + cutlass::Tensor4DCoord filter_size) { + this->input_size = input_size; + this->filter_size = filter_size; + + padding.n() = filter_size.h() / 2; + padding.h() = filter_size.h() / 2; + padding.w() = filter_size.w() / 2; + padding.c() = filter_size.w() / 2; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor4DCoord output_size() const { + return cutlass::Tensor4DCoord( + input_size.n(), + (input_size.h() + padding.n() + padding.h() - filter_size.h()) / + conv_stride.row() + + 1, + (input_size.w() + padding.w() + padding.c() - filter_size.w()) / + conv_stride.column() + + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Number of multiply-adds = NPQK * CRS / K + int64_t fmas = + output_size().product() * + int64_t(filter_size.h() * filter_size.w() * filter_size.c()) / + output_size().c(); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } + + }; + + #define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " \ + << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + #define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + ''' + if x.dtype == jt.float16: + cuda_header.replace("float", "cutlass::half_t") + cuda_src = ''' + // __global__ void kernel() {} + @alias(weights, in1) + @alias(x, in0) + @alias(dst, out0) + bool notSupported = false; + Options options = Options(); + options.update( + {x_shape0, x_shape2, x_shape3, x_shape1}, + {weights_shape0, weights_shape2, weights_shape3, 1}); + + cutlass::TensorRef d_src((ElementSrc*)x_p, + LayoutSrc().packed({options.input_size})); + cutlass::TensorRef d_filter((ElementFilter*)weights_p, + LayoutFilter().packed(options.filter_size)); + cutlass::TensorRef d_dst((ElementDst*)dst_p, LayoutDst().packed(options.output_size())); + cutlass::TensorRef d_bias = {nullptr, Convolution::LayoutDst()}; + cutlass::TensorRef d_z = {nullptr, Convolution::LayoutDst()}; + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + int split_k_slices = 1; + typename Convolution::Arguments arguments{ + {options.input_size, options.filter_size, options.padding, + options.conv_stride, options.dilation, options.output_size(), mode, + split_k_slices, options.filter_size.n()}, + d_src, // tensor_src.device_ref(), + d_filter, // tensor_filter.device_ref(), + d_bias, // tensor_bias.device_ref(), + d_z, // tensor_z.device_ref(), + d_dst, // tensor_dst.device_ref(), + {options.alpha, 0, options.beta}}; + Convolution conv_op; + + size_t workspace_size = conv_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(conv_op.initialize(arguments, workspace.get())); + // auto temp1 = exe.alloc_temp(workspace_size); + 
// CUTLASS_CHECK(conv_op.initialize(arguments, temp1.ptr)); + { + //static SimpleProfiler _("aa"); + //SimpleProfilerGuard __(_); + + //ccccc c + CUTLASS_CHECK(conv_op()); + // kernel<<<1,1>>>(); + } + ''' + output = jt.zeros(x.shape) + output = jt.code([x, weights], [output], cuda_header=cuda_header, cuda_src=cuda_src)[0] + output.compile_options = {f"FLAGS: --expt-relaxed-constexpr -I{cutlass_path}/include -I{cutlass_path}/tools/util/include ": 1} + return output + + def grad(self, g): + return depthwise_src_backward(g, self.weights), depthwise_filter_backward(self.x, self.weights, g) + +def backward_header(use_fp16 = False): + cuda_header = ''' + #pragma once + #undef out + #include + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + + // implement temp GPUMatrix and GPUDynamicMatrix here. + // #define RM MatrixLayout::kRowMajor + // #define CM MatrixLayout::kColumnMajor + // enum class MatrixLayout { + // kColumnMajor, + // kRowMajor + // }; + + enum class Activation { + ReLU, + Exponential, + Sine, + Sigmoid, + Squareplus, + Softplus, + None, + }; + + #define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + #define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + using SmArch = cutlass::arch::Sm80; + + using TypeAccumulator = float; + using TypeCompute = float; + using ElementComputeEpilogue = TypeAccumulator; // <- data type of epilogue operations + using MMAOp = cutlass::arch::OpClassTensorOp; + + // using ShapeMMAOp = typename std::conditional< + // std::is_same, cutlass::arch::OpClassTensorOp>::value, + // typename std::conditional< + // std::is_same::value || std::is_same::value, + // cutlass::gemm::GemmShape<16, 8, 8>, + // cutlass::gemm::GemmShape<8, 8, 4> + // >::type, + // cutlass::gemm::GemmShape<1, 1, 1> + // >::type; + + // using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; + template + struct LayerConfig { + using k_thread_block = thread_block; + using k_warp = warp; + }; + + using FullLayerK = LayerConfig, cutlass::gemm::GemmShape<32, 32, 32>>; + using LastLayerK = LayerConfig, cutlass::gemm::GemmShape<32, 32, 32>>; + + // using FullLayer = typename std::conditional< + // std::is_same, cutlass::arch::OpClassSimt>::value, + // LayerConfig, cutlass::gemm::GemmShape<32, 64, 8>>, + // LayerConfig, cutlass::gemm::GemmShape<64, 64, 32>> + // >::type; + + // using FullLayerPreReLU = typename std::conditional< + // std::is_same, cutlass::arch::OpClassSimt>::value, + // LayerConfig, cutlass::gemm::GemmShape<32, 64, 8, true>>, + // LayerConfig, cutlass::gemm::GemmShape<64, 64, 32, true>> + // >::type; + + // using LastLayer = typename std::conditional< + // std::is_same, cutlass::arch::OpClassSimt>::value, + // LayerConfig, cutlass::gemm::GemmShape<32, 64, 8>>, + // typename std::conditional< + // std::is_same::value || std::is_same::value, + // LayerConfig, cutlass::gemm::GemmShape<32, 32, 32>>, + // LayerConfig, cutlass::gemm::GemmShape<32, 32, 32>> + // >::type + // >::type; + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = 
cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + // warp activation defined here + template + __host__ __device__ void warp_activation(Activation activation, const fragment_t& frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + CUTLASS_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)((T)frag.x[t] > (T)0.0f); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } + } + + template + __host__ __device__ fragment_t warp_activation(Activation activation, const fragment_t& frag) { + fragment_t result; + warp_activation(activation, frag, result); + return result; + } + + + template + __host__ __device__ void warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag, fragment_t& result) { + switch (activation) { + case Activation::ReLU: + CUTLASS_PRAGMA_UNROLL + for (int t=0; t < result.num_elements; t++) { + result.x[t] = frag.x[t] * (T)(forward_frag.x[t] > (T)0.0f); + } + return; + case Activation::None: result = frag; return; + default: + // Unsupported activation + // assert(false); // Commented out due to isolated strange side-effects on Windows + return; + } + } + + template + __host__ __device__ fragment_t warp_activation_backward(Activation activation, const fragment_t& frag, const forward_fragment_t& forward_frag) { + fragment_t result; + warp_activation_backward(activation, frag, forward_frag, result); + return result; + } + + // // This code section describes the epilogue part of the kernel + + template + struct CutlassFragmentWrapper { + static const uint32_t num_elements = V::kElements; + V x; + }; + + template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + cutlass::FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest + > + class ActivationEpilogue { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = cutlass::Array; + using FragmentAccumulator = cutlass::Array; + using ComputeFragment = cutlass::Array; + + static cutlass::FloatRoundStyle const kRound = Round; + + struct Params { + Activation activation; + bool sum_source; + }; + + public: + CUTLASS_HOST_DEVICE + ActivationEpilogue(Params const ¶ms) : m_activation{params.activation}, m_sum_source{params.sum_source} { } + + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return m_sum_source; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { } + + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator const &accumulator) const { + cutlass::NumericArrayConverter accumulator_converter; + + auto intermediate = CutlassFragmentWrapper{accumulator_converter(accumulator)}; + intermediate = warp_activation(m_activation, intermediate); + + cutlass::NumericArrayConverter destination_converter; + return destination_converter(intermediate.x); + } + + CUTLASS_HOST_DEVICE + FragmentOutput operator()(FragmentAccumulator 
const &accumulator, FragmentOutput const &source) const { + cutlass::NumericArrayConverter source_converter; + cutlass::NumericArrayConverter accumulator_converter; + + cutlass::plus plus_op; + auto intermediate = CutlassFragmentWrapper{accumulator_converter(accumulator)}; + if (m_sum_source) { + intermediate.x = plus_op(intermediate.x, source_converter(source)); + } + intermediate = warp_activation(m_activation, intermediate); + + cutlass::NumericArrayConverter destination_converter; + return destination_converter(intermediate.x); + } + + private: + Activation m_activation; + bool m_sum_source; + }; + + template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + cutlass::FloatRoundStyle Round = cutlass::FloatRoundStyle::round_to_nearest + > + class ActivationTransferEpilogue { + public: + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = cutlass::Array; + using FragmentAccumulator = cutlass::Array; + using ComputeFragment = cutlass::Array; + + static cutlass::FloatRoundStyle const kRound = Round; + + /// Host-constructable parameters structure + struct Params { + Activation activation; + }; + + public: + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + ActivationTransferEpilogue(Params const ¶ms) : m_activation{params.activation} { } + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { } + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator, + FragmentOutput const &source) const { + + cutlass::NumericArrayConverter source_converter; + cutlass::NumericArrayConverter accumulator_converter; + + auto converted_source = CutlassFragmentWrapper{source_converter(source)}; + auto intermediate = CutlassFragmentWrapper{accumulator_converter(accumulator)}; + + intermediate = warp_activation_backward(m_activation, intermediate, converted_source); + + cutlass::NumericArrayConverter destination_converter; + return destination_converter(intermediate.x); + } + + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &accumulator) const { + + cutlass::NumericArrayConverter accumulator_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + cutlass::NumericArrayConverter destination_converter; + + return destination_converter(converted_accumulator); + } + + private: + Activation m_activation; + }; + + + // template + // static constexpr int n_vectorized_elements = std::is_same, cutlass::arch::OpClassTensorOp>::value ? 
(128 / cutlass::sizeof_bits::value) : 1; + + // template + // using SumOp = cutlass::epilogue::thread::LinearCombination, TypeAccumulator, TypeCompute>; + + // template + // using IntermediateActivationOp = ActivationEpilogue; + + // template + // using IntermediateActivationTransferOp = ActivationTransferEpilogue; + + // template + // using ActivationOp = ActivationEpilogue, TypeAccumulator, TypeCompute>; + + // template + // using ActivationTransferOp = ActivationTransferEpilogue, TypeAccumulator, TypeCompute>; + + using OurGemm = cutlass::gemm::device::Gemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + TypeAccumulator, + MMAOp, + SmArch, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + ActivationEpilogue, + SwizzleThreadBlock, + 2 + >; + + using OurGemmTransfer = cutlass::gemm::device::Gemm< + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + TypeAccumulator, + MMAOp, + SmArch, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 32, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + ActivationTransferEpilogue, + SwizzleThreadBlock, + 2 + >; + + // using epi = cutlass::epilogue::thread::LinearCombination; + // using SplitKGemm = cutlass::gemm::device::GemmSplitKParallel< + // float, + // cutlass::layout::ColumnMajor, + // float, + // cutlass::layout::RowMajor, + // float, + // cutlass::layout::RowMajor, + // TypeAccumulator, + // MMAOp, + // SmArch, + // cutlass::gemm::GemmShape<128, 128, 32>, + // cutlass::gemm::GemmShape<64, 32, 32>, + // cutlass::gemm::GemmShape<16, 8, 8>, + // epi + // >; + + using OurGemmW = cutlass::gemm::device::Gemm< + float, + cutlass::layout::ColumnMajor, + float, + cutlass::layout::RowMajor, + float, + cutlass::layout::RowMajor, + TypeAccumulator, + MMAOp, + SmArch, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 8>, + ActivationEpilogue, + SwizzleThreadBlock, + 2 + >; + + void backward(float* input, float* grad, float* output, float* output_grad, int input_m, int input_n, int grad_m, int grad_n, int output_m, int output_n) { // grad * weight.T + using Gemm = OurGemmTransfer; + const int lda = grad_n; + const int ldb = input_n; + const int ldc = output_n; + const int ldd = output_n; + typename Gemm::Arguments arguments{ + {grad_m, input_m, grad_n}, // TODO + {grad, lda}, + {input, ldb}, + {output, ldc}, + {output_grad, ldc}, + {Activation::ReLU}, + 1 + }; + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::device_memory::allocation workspace(workspace_size); + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(0); + CUTLASS_CHECK(status); + } + + void last_inp_backward(float* input, float* grad, float* output, int input_m, int input_n, int grad_m, int grad_n, int output_m, int output_n) { // output * weight.T + using Gemm = OurGemm; + const int lda = grad_n; + const int ldb = input_n; + const int ldc = output_n; + const int ldd = output_n; + typename Gemm::Arguments arguments{ + {grad_m, input_m, grad_n}, // TODO + {grad, lda}, + {input, ldb}, + {output, ldc}, + {output, ldc}, + {Activation::None, 
false}, + 1 + }; + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::device_memory::allocation workspace(workspace_size); + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + } + + void weight_backward(float* input, float* grad, float* weight_grad, int input_m, int input_n, int grad_m, int grad_n, int weight_grad_m, int weight_grad_n) { // A.T * GRAD + + int batch_size = grad_n; + + using Gemm = OurGemmW; + const int lda = input_n; + const int ldb = grad_n; + const int ldc = weight_grad_n; + const int ldd = weight_grad_n; + typename Gemm::Arguments arguments{ + {input_n, grad_n, grad_m}, // TODO + {input, lda}, + {grad, ldb}, + {weight_grad, ldc}, + {weight_grad, ldc}, + {Activation::None, false}, + 1 + }; + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Initialize CUTLASS kernel with arguments and workspace pointer + cutlass::device_memory::allocation workspace(workspace_size); + cutlass::Status status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + } + ''' + if use_fp16: + cuda_header = cuda_header.replace("float", "cutlass::half_t") + return cuda_header + +class FullyFusedMlp(jt.Function): + def __init__(self): + self.input = None + self.outputs = [] + self.shapes = [] + self.dtypes = [] + self.weights_grad = [] + self.max_dim = 0 + self.weights = None + + def single_forward(self, a, b): + cuda_header = ''' + #undef out + #include + #include + + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include + #include "executor.h" + + #define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + if (error != cutlass::Status::kSuccess) { \ + std::cerr << "Got cutlass error: " \ + << cutlassGetStatusString(error) << " at: " << __LINE__ \ + << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + #define CUDA_CHECK(status) \ + { \ + cudaError_t error = status; \ + if (error != cudaSuccess) { \ + std::cerr << "Got bad cuda status: " << cudaGetErrorString(error) \ + << " at line: " << __LINE__ << std::endl; \ + exit(EXIT_FAILURE); \ + } \ + } + + // The code section below describes datatype for input, output matrices and computation between + // elements in input matrices. + using ElementAccumulator = float; // <- data type of accumulator + using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations + using ElementInputA = float; // <- data type of elements in input matrix A + using ElementInputB = float; // <- data type of elements in input matrix B + using ElementOutput = float; // <- data type of elements in output matrix D + + // The code section below describes matrix layout of input and output matrices. + // Column Major for Matrix A, B and C. + // + // Note this example only works for ColumnMajor output because + // 1) we only have row major epilogue. + // 2) we swap A and B if the output is column major then we can still use the + // row major epilogue. + // 3) Mx1 bias vector becomes 1xM after the swapping/transposing. + // 4) we can use the existing OutputIterator to load 1xM bias vector. 
+ + using LayoutInputA = cutlass::layout::RowMajor; + using LayoutInputB = cutlass::layout::RowMajor; + using LayoutOutput = cutlass::layout::RowMajor; + + // This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM + using MMAOp = cutlass::arch::OpClassTensorOp; + + // This code section describes CUDA SM architecture number + using SmArch = cutlass::arch::Sm75; + + // This code section describes the tile size a thread block will compute + using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 32>; // <- threadblock tile M = 128, N = 128, K = 32 + // This code section describes tile size a warp will compute + using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 + // This code section describes the size of MMA op + using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 8, N = 8, K = 4 + + // This code section describes how threadblocks are scheduled on GPU + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? + + // Define the epilogue operation as LinearCombinationRelu. This is approximately equal to + // + // d_ij = max(0, alpha * sum_k(a_ik * b_kj) + c_ij ) + // + using EpilogueOp = cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- this is the number of elements per + // vectorized memory access. For half + // precision, it's 8 elements. This becomes + // the vector width of math instructions in + // epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue, // <- data type for alpha in linear combination function + cutlass::epilogue::thread::ScaleType::NoBetaScaling>; // <- alpha x C + bias + + // Number of pipelines you want to use + constexpr int NumStages = 2; + + using Gemm = cutlass::gemm::device::Gemm; + ''' + if a.dtype == jt.float16: + cuda_header = cuda_header.replace("float", "cutlass::half_t") + cuda_src = ''' + @alias(b, in1) + @alias(a, in0) + @alias(c, out0) + const int length_m = a_shape0; + const int length_n = b_shape1; + const int length_k = a_shape1; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + + // Initialize tensors using CUTLASS helper functions + cutlass::TensorRef tensor_a((ElementInputA*)a_p, + LayoutInputA().packed(problem_size.mk())); + cutlass::TensorRef tensor_b((ElementInputB*)b_p, + LayoutInputB().packed(problem_size.kn())); + cutlass::TensorRef tensor_d((ElementOutput*)c_p, + LayoutOutput().packed(problem_size.mn())); + + // Initialize alpha for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{ + problem_size, // <- problem size of matrix multiplication + tensor_a, // <- reference to matrix A on device + tensor_b, // <- reference to matrix B on device + + {NULL, 0}, // <- the C matrix is treated as the bias vector. We can enable the GEMM + // to project away the N dimension by setting the stride to zero. 
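+        //    No bias tensor is actually supplied in this call (null pointer,
+        //    stride 0), so the intended result is simply D = ReLU(alpha * A * B).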
+ + tensor_d, // <- reference to matrix D on device, + {alpha}, // <- alpha + split_k_slices}; // <- k-dimension split factor + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + + // Allocate workspace memory + auto temp1 = exe.alloc_temp(workspace_size); + + // Allocate workspace memory + // cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, temp1.ptr); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + ''' + output = jt.code((a.shape[0], b.shape[1]), a.dtype, [a, b], cuda_header=cuda_header, cuda_src=cuda_src) + output.compile_options = {f"FLAGS: --expt-relaxed-constexpr -I{cutlass_path}/include -I{cutlass_path}/tools/util/include ": 1} + return output + + def execute(self, a, *args): + self.shapes = [] + self.dtypes = [] + self.max_dim = 0 + self.weights = list(args) + self.input = a + weights = args + for i in range(len(weights)): + self.outputs.append(self.single_forward(a, weights[i])) + a = self.outputs[-1] + # print(self.outputs) + return self.outputs[-1] + + def backward(self, grad, weight, output): + use_fp16 = True if grad.dtype == jt.float16 else False + cuda_header = backward_header(use_fp16) + converter = "(cutlass::half_t*)" if use_fp16 else "" + cuda_src = f''' + @alias(input, in0) + @alias(grad, in1) + @alias(weight, in2) + @alias(weight_grad, out0) + @alias(inp_grad, out1) + weight_backward({converter}input_p, {converter}grad_p, {converter}weight_grad_p, input_shape0, input_shape1, grad_shape0, grad_shape1, weight_shape0, weight_shape1); + backward({converter}weight_p, {converter}grad_p, {converter}input_p, {converter}inp_grad_p, weight_shape0, weight_shape1, grad_shape0, grad_shape1, input_shape0, input_shape1); + ''' + weight_grad, out_grad = jt.code([weight.shape, output.shape], [weight.dtype, output.dtype], [output, grad, weight], cuda_header=cuda_header, cuda_src=cuda_src) + weight_grad.compile_options = {f"FLAGS: --expt-relaxed-constexpr -I{cutlass_path}/include -I{cutlass_path}/tools/util/include ": 1} + return out_grad, weight_grad + + def last_backward(self, grad, weight, output): + use_fp16 = True if grad.dtype == jt.float16 else False + cuda_header = backward_header(use_fp16) + converter = "(cutlass::half_t*)" if use_fp16 else "" + cuda_src = f''' + @alias(input, in0) + @alias(grad, in1) + @alias(weight, in2) + @alias(weight_grad, out0) + @alias(out_grad, out1) + weight_backward({converter}input_p, {converter}grad_p, {converter}weight_grad_p, input_shape0, input_shape1, grad_shape0, grad_shape1, weight_shape0, weight_shape1); + last_inp_backward({converter}weight_p, {converter}grad_p, {converter}out_grad_p, weight_shape0, weight_shape1, grad_shape0, grad_shape1, input_shape0, input_shape1); + ''' + weight_grad, out_grad = jt.code([weight.shape, output.shape], [weight.dtype, output.dtype], [output, grad, weight], cuda_header=cuda_header, cuda_src=cuda_src) + weight_grad.compile_options = {f"FLAGS: --expt-relaxed-constexpr -I{cutlass_path}/include -I{cutlass_path}/tools/util/include ": 1} + return out_grad, weight_grad + + def grad(self, grads): + 
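+        # Walk the layers from last to first: first zero the incoming gradient
+        # wherever the final ReLU output was zero, then use backward() for the
+        # hidden layers (it applies the ReLU transfer against the stored forward
+        # output) and last_backward() for the first layer, whose stored input is
+        # the raw network input.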
self.weights_grad = [] + output = self.outputs[-1] + grads[output == 0] = 0 + num_hidden = len(self.weights)-1 + for idx in range(num_hidden, -1, -1): + if idx == 0: + grads, weight_grad = self.last_backward(grads, self.weights[0], self.input) + self.weights_grad.insert(0, weight_grad) + else: + grads, weight_grad = self.backward(grads, self.weights[idx], self.outputs[idx - 1]) + self.weights_grad.insert(0, weight_grad) + return (grads, *self.weights_grad) \ No newline at end of file diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc index 46d9fc23..87c432c0 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc +++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.cc @@ -11,6 +11,7 @@ #include "var.h" #include "cublas_matmul_op.h" #include "cublas_wrapper.h" +#include "ops/op_register.h" using namespace std; @@ -49,6 +50,27 @@ void CublasMatmulOp::infer_shape() { c->set_shape({n, k}); } +static auto make_cublas_matmul = get_op_info("cublas_matmul") + .get_constructor(); + +VarPtr CublasMatmulOp::grad(Var* out, Var* dout, Var* v, int v_index) { + // a [b,n,m] b [b,m,k], c[b,n,k] + // c = a*b + if (v_index == 0) { + if (trans_a) + return make_cublas_matmul(b, dout, trans_b, 1); + else + // da = dc*b^T + return make_cublas_matmul(dout, b, 0, trans_b^1); + } else { + if (trans_b) + return make_cublas_matmul(dout, a, 1, trans_a); + else + // db = a^T*dc + return make_cublas_matmul(a, dout, trans_a^1, 0); + } +} + void CublasMatmulOp::jit_prepare(JK& jk) { jk << "«T:" << a->dtype(); jk << "«Trans_a:" << (trans_a ? 'T' : 'N'); diff --git a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.h b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.h index 24c5d1b8..c7435d87 100644 --- a/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.h +++ b/python/jittor/extern/cuda/cublas/ops/cublas_matmul_op.h @@ -19,6 +19,7 @@ struct CublasMatmulOp : Op { const char* name() const override { return "cublas_matmul"; } void infer_shape() override; + VarPtr grad(Var* out, Var* dout, Var* v, int v_index) override; DECLARE_jit_run; }; diff --git a/python/jittor/extern/cuda/cutlass/ops/cutlass_matmul_op.cc b/python/jittor/extern/cuda/cutlass/ops/cutlass_matmul_op.cc new file mode 100644 index 00000000..61b50486 --- /dev/null +++ b/python/jittor/extern/cuda/cutlass/ops/cutlass_matmul_op.cc @@ -0,0 +1,134 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Guowei Yang <471184555@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. 
+// *************************************************************** + +#ifdef JIT +#include +#endif +#include "var.h" +#include "cutlass_matmul_op.h" + +using namespace std; + +namespace jittor { + +extern int use_tensorcore; + +#ifndef JIT + +CutlassMatmulOp::CutlassMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b) + : a(a), b(b), trans_a(trans_a), trans_b(trans_b) { + flags.set(NodeFlags::_cuda, 1); + flags.set(NodeFlags::_cpu, 0); + // TODO: support int8 * int8 + ASSERT(a->dtype().is_float() && b->dtype().is_float()) << "type of two inputs should be the same"; + // TODO: support diffrent input type + ASSERT(a->dtype().dsize() == b->dtype().dsize()) << "type of two inputs should be the same"; + c = create_output(nullptr, a->dtype()); +} + +void CutlassMatmulOp::infer_shape() { + ASSERTop(a->shape.size(),==,2); + ASSERTop(b->shape.size(),==,2); + int n = a->shape[0], m = a->shape[1]; + int m_ = b->shape[0], k = b->shape[1]; + if (trans_a) { + swap(n, m); + } + if (trans_b) { + swap(m_, k); + } + ASSERTop(m,==,m_); + c->set_shape({n, k}); +} + +void CutlassMatmulOp::jit_prepare(JK& jk) { + jk << _CS("[T:") << a->dtype(); + jk << _CS("][Trans_a:") << (trans_a ? 'T' : 'N'); + jk << _CS("][Trans_b:") << (trans_b ? 'T' : 'N'); + jk << _CS("][op:") << (a->dtype().dsize() == 4 ? 'S' : 'D'); + jk << ']'; +} + +#else // JIT +#pragma clang diagnostic ignored "-Wtautological-compare" + + +static inline cudaError_t CutlassSgemmNN( + int M, + int N, + int K, + float alpha, + float const *A, + int lda, + float const *B, + int ldb, + float beta, + float *C, + int ldc) { + + // Define type definition for single-precision CUTLASS GEMM with column-major + // input matrices and 128x128x8 threadblock tile size (chosen by default). + // + // To keep the interface manageable, several helpers are defined for plausible compositions + // including the following example for single-precision GEMM. Typical values are used as + // default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details. + // + // To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h` + + using ColumnMajor = cutlass::layout::ColumnMajor; + using RowMajor = cutlass::layout::RowMajor; + using CutlassGemm = cutlass::gemm::device::Gemm; // Layout of C matrix + + CutlassGemm gemm_operator; + + CutlassGemm::Arguments args({M, N, K}, // Gemm Problem dimensions + {A, lda}, // Tensor-ref for source matrix A + {B, ldb}, // Tensor-ref for source matrix B + {C, ldc}, // Tensor-ref for source matrix C + {C, ldc}, // Tensor-ref for destination matrix D (may be different memory than source C matrix) + {alpha, beta}); // Scalars used in the Epilogue + + cutlass::Status status = gemm_operator(args); + if (status != cutlass::Status::kSuccess) { + return cudaErrorUnknown; + } + return cudaSuccess; +} + +void CutlassMatmulOp::jit_run() { + const T alpha = 1.0f; + const T beta = 0.0f; + LOGi << "herher"; + const auto& as = a->shape; + const auto& bs = b->shape; + auto n = as[0]; + auto m = as[1]; + auto k = bs[1]; + if ('@Trans_a'=='T') { + n = as[1]; + m = as[0]; + } + if ('@Trans_b'=='T') { + k = bs[0]; + } + using ColumnMajor = cutlass::layout::ColumnMajor; + using RowMajor = cutlass::layout::RowMajor; + // a: [n,m], b: [m,k], c: [n,k] + cudaError_t result = CutlassSgemmNN(n, k, m, alpha, a->ptr(), '@Trans_a' == 'N' ? m : n, b->ptr(), '@Trans_b' == 'N' ? 
k : m, beta, c->ptr(), k); + assert(result == cudaSuccess); +} +#endif // JIT + +} // jittor diff --git a/python/jittor/extern/cuda/cutlass/ops/cutlass_matmul_op.h b/python/jittor/extern/cuda/cutlass/ops/cutlass_matmul_op.h new file mode 100644 index 00000000..b6cc0f75 --- /dev/null +++ b/python/jittor/extern/cuda/cutlass/ops/cutlass_matmul_op.h @@ -0,0 +1,25 @@ +// *************************************************************** +// Copyright (c) 2021 Jittor. All Rights Reserved. +// Maintainers: +// Guoye Yang <498731903@qq.com> +// Dun Liang . +// +// This file is subject to the terms and conditions defined in +// file 'LICENSE.txt', which is part of this source code package. +// *************************************************************** +#pragma once +#include "op.h" + +namespace jittor { + +struct CutlassMatmulOp : Op { + Var* a, * b, * c; + bool trans_a, trans_b; + CutlassMatmulOp(Var* a, Var* b, bool trans_a, bool trans_b); + + const char* name() const override { return "cutlass_matmul"; } + void infer_shape() override; + DECLARE_jit_run; +}; + +} // jittor \ No newline at end of file diff --git a/python/jittor/nn.py b/python/jittor/nn.py index c695b669..ad82d8e4 100644 --- a/python/jittor/nn.py +++ b/python/jittor/nn.py @@ -18,6 +18,7 @@ import numpy as np import collections import math +import os from collections import OrderedDict from jittor.pool import * from jittor.optim import * @@ -757,7 +758,10 @@ def group_norm(x, GELU = jt.make_module(gelu) Flatten = jt.make_module(jt.flatten) -from jittor.depthwise_conv import DepthwiseConv +if os.environ.get('use_cutlass') == '0': + from jittor.depthwise_conv import DepthwiseConv +else: + from jittor.cutlass_ops import DepthwiseConv, FullyFusedMlp class Conv(Module): ''' Applies a 2D convolution over an input signal composed of several input planes. @@ -2799,4 +2803,29 @@ def _fft2(x, inverse=False): y = jt.compile_extern.cufft_ops.cufft_fft(x, inverse) if inverse: y /= x.shape[1] * x.shape[2] - return y \ No newline at end of file + return y + + +class FullyFusedMLP(Module): + ''' fusing multiple linear layers in one FullyFusedMLP layer using ReLU activation. + Example:: + + m = nn.FullyFusedMLP((20, 30, 40)) # with 2 layers + input1 = jt.randn(128, 20) + output = m(input1) + print(output.shape) + # [128, 40] + ''' + def __init__(self, weights_width=None, weights=None): + assert os.environ.get("use_cutlass") == '1', "Need cutlass support!" + if weights != None: + self.weights = weights + else: + assert weights_width != None, "All inputs are None." 
diff --git a/python/jittor/src/jit_compiler.cc b/python/jittor/src/jit_compiler.cc
index 9b7d75ce..63eadca4 100644
--- a/python/jittor/src/jit_compiler.cc
+++ b/python/jittor/src/jit_compiler.cc
@@ -211,7 +211,7 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
     LOGvv << "Compile op" << jit_key;
     // compiler do not allowed filename too long
     CHECK(cc_path.size());
-    string jit_src_path = Op::get_filename_from_jit_key(jit_key, ".cc");
+    string jit_src_path = Op::get_filename_from_jit_key(jit_key, is_cuda_op?".cu":".cc");
     string* src2 = (string*)&src;
     string* extra_flags2 = (string*)&extra_flags;
     JPU(op_compiler(jit_src_path, *src2, is_cuda_op, *extra_flags2));
@@ -235,6 +235,9 @@ jit_op_entry_t compile(const string& jit_key, const string& src, const bool is_c
            + " \"" + jit_src_path + "\""
            + other_src
            + fix_cl_flags(nvcc_flags + extra_flags, is_cuda_op) + " -o \"" + jit_lib_path + "\"";
+        if (cmd.find("-dc") != string::npos) {
+            cmd = python_path+" "+jittor_path+"/utils/dlink_compiler.py " + cmd;
+        }
     } else {
         cmd = "\"" + cc_path + "\""
             + " \"" + jit_src_path + "\"" + other_src
diff --git a/python/jittor/src/op_compiler.cc b/python/jittor/src/op_compiler.cc
index 11f753aa..8b42be33 100644
--- a/python/jittor/src/op_compiler.cc
+++ b/python/jittor/src/op_compiler.cc
@@ -865,7 +872,7 @@ string OpCompiler::__get_fused_src(
         "int", "float", "bool", "CHECK", "STRINGIZE",
         "void", "__restrict__", "if", "true", "false",
         "Op", "Var", "Node", "itof", "assert", "ASSERT",
-        "float64"
+        "float64", "float16"
     };
     auto not_change = [&](const string& s) -> bool {
         if (unchanged.count(s)) return true;
diff --git a/python/jittor/src/opt/pass_manager.cc b/python/jittor/src/opt/pass_manager.cc
index 07c3eebf..534154ef 100644
--- a/python/jittor/src/opt/pass_manager.cc
+++ b/python/jittor/src/opt/pass_manager.cc
@@ -77,6 +79,7 @@ void PassManager::run_passes() {
         }
         return;
     }
+
    run_pass();
    run_pass();
    run_pass();
@@ -96,7 +99,6 @@ void PassManager::run_passes() {
    run_pass();
    // tmp disable ConstVarPass
    // run_pass();
-   run_pass();
 
    if (cc_type == "icc") {
@@ -106,17 +108,19 @@ void PassManager::run_passes() {
        run_pass();
    }
    run_pass();
+   run_pass();
+   run_pass();
    run_pass();
+   run_pass();
    run_pass();
+   run_pass();
 
-   run_pass();
    run_pass();
-   run_pass();
 }
diff --git a/python/jittor/src/profiler/profiler.cc b/python/jittor/src/profiler/profiler.cc
index fbf879cd..8761d881 100644
--- a/python/jittor/src/profiler/profiler.cc
+++ b/python/jittor/src/profiler/profiler.cc
@@ -428,7 +428,10 @@ vector<vector<string>> Profiler::report(const string& sort_key) {
             continue;
         auto& fname = fnames[i];
         rep.push_back({name, fname});
-        ss << std::setw(w) << name;
+        if (name.size() > 100)
+            ss << std::setw(w) << name.substr(0, 100);
+        else
+            ss << std::setw(w) << name;
         if (name.size() >= w-1)
             ss << "\n" << std::setw(w) << " ";
         ss << std::setw(w) << fname;
diff --git a/python/jittor/src/type/fp16_op_type.cc b/python/jittor/src/type/fp16_op_type.cc
index 9aef9b03..851d7217 100644
--- a/python/jittor/src/type/fp16_op_type.cc
+++ b/python/jittor/src/type/fp16_op_type.cc
@@ -8,6 +8,7 @@
 #include "utils/str_utils.h"
 #include "ops/op_register.h"
 #include "op_compiler.h"
+#include "fused_op.h"
 
 namespace jittor {
 
diff --git a/python/jittor/utils/dlink_compiler.py b/python/jittor/utils/dlink_compiler.py
new file mode 100644
index 00000000..ba1b9ffb
--- /dev/null
+++ b/python/jittor/utils/dlink_compiler.py
@@ -0,0 +1,31 @@
+# Helper used by jit_compiler.cc when a CUDA op is compiled with "nvcc -dc"
+# (relocatable device code): first compile the .cu source to an object file,
+# then device-link the objects into the final shared library.
+import sys
+import os
+cmds = sys.argv[1:]
+def replace(cmds, s, t):
+    return [ c.replace(s,t) for c in cmds ]
+def remove(cmds, ss):
+    # keep only the arguments that contain none of the substrings in ss
+    rets = []
+    for cmd in cmds:
+        keep = True
+        for s in ss:
+            if s in cmd:
+                keep = False
+                break
+        if keep:
+            rets.append(cmd)
+    return rets
+
+# compile step: drop pre-built .o arguments and turn the .so target into a .o
+cmds1 = remove(cmds, [".o"])
+cmds1 = replace(cmds1, ".so", ".o")
+# link step: drop -dc and link the freshly built .o instead of the .cu source
+cmds2 = replace(cmds, "-dc", "")
+cmds2 = replace(cmds2, ".cu", ".o")
+ret = os.system(" ".join(cmds1).replace("-x cu", ""))
+if ret: exit(ret)
+ret = os.system(" ".join(cmds2).replace("-x cu", ""))
+if ret: exit(ret)
\ No newline at end of file
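A rough illustration of the command rewrite performed by dlink_compiler.py; the nvcc command below is made up for illustration and is not taken from the patch:

    # illustrative argv[1:], not a real Jittor command line
    cmds = ["nvcc", "-x", "cu", "-dc", "my_op.cu", "helper.o", "-shared", "-o", "my_op.so"]
    # compile step (cmds1): "helper.o" is dropped, "my_op.so" becomes "my_op.o",
    # and "-x cu" is stripped when the command is joined:
    #     nvcc -dc my_op.cu -shared -o my_op.o
    # link step (cmds2): "-dc" is dropped and "my_op.cu" becomes "my_op.o":
    #     nvcc my_op.o helper.o -shared -o my_op.so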