diff --git a/ffcx/codegeneration/C/integrals.py b/ffcx/codegeneration/C/integrals.py
index 1115e3731..226b2dabc 100644
--- a/ffcx/codegeneration/C/integrals.py
+++ b/ffcx/codegeneration/C/integrals.py
@@ -69,6 +69,19 @@ def generator(ir: IntegralIR, options):
     else:
         code["tabulate_tensor_complex64"] = ".tabulate_tensor_complex64 = NULL,"
         code["tabulate_tensor_complex128"] = ".tabulate_tensor_complex128 = NULL,"
+    if options.get("cuda"):
+        # Each entry is one member of the C designated initializer in the
+        # integral template, so it must end with a comma like its neighbours.
+        # NOTE(review): the CUDA wrapper template also has a
+        # {tabulate_tensor_quoted} field, but nothing visible in this patch
+        # passes it to ``format`` below -- confirm before enabling "cuda".
+        code["tabulate_tensor_cuda_nvrtc"] = (
+            f".tabulate_tensor_cuda_nvrtc = tabulate_tensor_cuda_nvrtc_{factory_name},"
+        )
+    else:
+        # Explicit NULL, matching the disabled complex entries above.
+        code["tabulate_tensor_cuda_nvrtc"] = ".tabulate_tensor_cuda_nvrtc = NULL,"
+
     np_scalar_type = np.dtype(options["scalar_type"]).name
     code[f"tabulate_tensor_{np_scalar_type}"] = (
         f".tabulate_tensor_{np_scalar_type} = tabulate_tensor_{factory_name},"
@@ -76,7 +89,7 @@ def generator(ir: IntegralIR, options):
 
     element_hash = 0 if ir.coordinate_element_hash is None else ir.coordinate_element_hash
 
-    implementation = ufcx_integrals.factory.format(
+    implementation = ufcx_integrals.get_factory(options).format(
         factory_name=factory_name,
         enabled_coefficients=code["enabled_coefficients"],
         enabled_coefficients_init=code["enabled_coefficients_init"],
@@ -89,6 +102,7 @@ def generator(ir: IntegralIR, options):
         tabulate_tensor_float64=code["tabulate_tensor_float64"],
         tabulate_tensor_complex64=code["tabulate_tensor_complex64"],
         tabulate_tensor_complex128=code["tabulate_tensor_complex128"],
+        tabulate_tensor_cuda_nvrtc=code["tabulate_tensor_cuda_nvrtc"],
     )
 
     return declaration, implementation
diff --git a/ffcx/codegeneration/C/integrals_template.py b/ffcx/codegeneration/C/integrals_template.py
index 2bb1568ec..28e36a508 100644
--- a/ffcx/codegeneration/C/integrals_template.py
+++ b/ffcx/codegeneration/C/integrals_template.py
@@ -30,9 +30,70 @@
   {tabulate_tensor_float64}
   {tabulate_tensor_complex64}
   {tabulate_tensor_complex128}
+  {tabulate_tensor_cuda_nvrtc}
   .needs_facet_permutations = {needs_facet_permutations},
   .coordinate_element_hash = {coordinate_element_hash},
 }};
 
 // End of code for integral {factory_name}
 """
+
+# Host-side C wrapper emitted ahead of ``factory`` when the "cuda" option is
+# set; it returns the kernel source as a string literal for NVRTC. Format
+# fields: factory_name, scalar_type, geom_type, tabulate_tensor_quoted.
+cuda_wrapper = """
+
+// Begin NVRTC CUDA wrapper for integral {factory_name}
+// The wrapper is compiled with a standard C++ compiler, and is called at runtime to generate
+// source code which is then compiled into a CUDA kernel at runtime via NVRTC.
+void tabulate_tensor_cuda_nvrtc_{factory_name}(int* num_program_headers,
+                                               const char*** program_headers,
+                                               const char*** program_include_names,
+                                               const char** out_program_src,
+                                               const char** tabulate_tensor_function_name)
+{{
+  // The below typedefs are needed due to issues with including stdint.h in NVRTC source code
+  const char* program_src = ""
+    "#define alignas(x)\\n"
+    "#define restrict __restrict__\\n"
+    "\\n"
+    "typedef unsigned char uint8_t;\\n"
+    "typedef unsigned int uint32_t;\\n"
+    "typedef {scalar_type} ufc_scalar_t;\\n"
+    "\\n"
+    "extern \\"C\\" __global__\\n"
+    "void tabulate_tensor_{factory_name}({scalar_type}* restrict A,\\n"
+    "                                    const {scalar_type}* restrict w,\\n"
+    "                                    const {scalar_type}* restrict c,\\n"
+    "                                    const {geom_type}* restrict coordinate_dofs,\\n"
+    "                                    const int* restrict entity_local_index,\\n"
+    "                                    const uint8_t* restrict quadrature_permutation\\n"
+    "                                    )\\n"
+    "{{\\n"
+    "{tabulate_tensor_quoted}\\n"
+    "}}";
+  *num_program_headers = 0;
+  *program_headers = NULL;
+  *program_include_names = NULL;
+  *out_program_src = program_src;
+  *tabulate_tensor_function_name = "tabulate_tensor_{factory_name}";
+}}
+
+// End NVRTC CUDA wrapper for integral {factory_name}
+
+"""
+
+
+def get_factory(options):
+    """Return the template string for constructing form integrals.
+
+    Args:
+        options: FFCx options mapping; only the ``"cuda"`` entry is read.
+
+    Returns:
+        ``factory``, preceded by ``cuda_wrapper`` when the ``"cuda"`` entry
+        of ``options`` is truthy.
+    """
+    if options.get("cuda"):
+        return cuda_wrapper + factory
+    return factory
diff --git a/ffcx/codegeneration/jit.py b/ffcx/codegeneration/jit.py
index 6eb5dbb8f..467de689c 100644
--- a/ffcx/codegeneration/jit.py
+++ b/ffcx/codegeneration/jit.py
@@ -68,6 +68,9 @@
 UFC_INTEGRAL_DECL += "\n".join(
     re.findall(r"typedef void ?\(ufcx_tabulate_tensor_complex128\).*?\);", ufcx_h, re.DOTALL)
 )
+UFC_INTEGRAL_DECL += "\n".join(
+    re.findall(r"typedef void ?\(ufcx_tabulate_tensor_cuda_nvrtc\).*?\);", ufcx_h, re.DOTALL)
+)
 
 UFC_INTEGRAL_DECL += "\n".join(
     re.findall("typedef struct ufcx_integral.*?ufcx_integral;", ufcx_h, re.DOTALL)
diff --git a/ffcx/codegeneration/ufcx.h b/ffcx/codegeneration/ufcx.h
index e1dd838d1..782a1a04d 100644
--- a/ffcx/codegeneration/ufcx.h
+++ b/ffcx/codegeneration/ufcx.h
@@ -125,6 +125,28 @@ extern "C"
       const uint8_t* restrict quadrature_permutation);
 #endif // __STDC_NO_COMPLEX__
 
+  /// Return CUDA C++ source code for the tabulate_tensor kernel.
+  /// The resulting source code is passed to NVRTC for runtime compilation.
+  ///
+  /// @param[out] num_program_headers
+  ///         The number of headers required by the program
+  /// @param[out] program_headers
+  ///         Entire contents of each header file
+  /// @param[out] program_include_names
+  ///         Names of each header file
+  /// @param[out] program_src
+  ///         CUDA C++ source code for the program containing the
+  ///         tabulate_tensor function.
+  /// @param[out] tabulate_tensor_function_name
+  ///         The name of the kernel (`__global__` function) in program_src.
+  ///
+  typedef void(ufcx_tabulate_tensor_cuda_nvrtc)(
+      int* num_program_headers,
+      const char*** program_headers,
+      const char*** program_include_names,
+      const char** program_src,
+      const char** tabulate_tensor_function_name);
+
   typedef struct ufcx_integral
   {
     const bool* enabled_coefficients;
@@ -134,6 +156,7 @@ extern "C"
     ufcx_tabulate_tensor_complex64* tabulate_tensor_complex64;
     ufcx_tabulate_tensor_complex128* tabulate_tensor_complex128;
 #endif // __STDC_NO_COMPLEX__
+    ufcx_tabulate_tensor_cuda_nvrtc* tabulate_tensor_cuda_nvrtc;
     bool needs_facet_permutations;
 
     /// Get the hash of the coordinate element associated with the geometry of the mesh.
diff --git a/ffcx/options.py b/ffcx/options.py
index 536f02a35..252ea74a8 100644
--- a/ffcx/options.py
+++ b/ffcx/options.py
@@ -20,6 +20,7 @@
 logger = logging.getLogger("ffcx")
 
 FFCX_DEFAULT_OPTIONS = {
+    "cuda": (bool, False, "generate CUDA wrapped versions of tabulate tensor functions", None),
     "epsilon": (float, 1e-14, "machine precision, used for dropping zero terms in tables.", None),
     "scalar_type": (
         str,