Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow user to define the maximum number of kernel arguments #718

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ opt
/.compiledDefines
/include/occa/defines/compiledDefines.hpp
/include/occa/scripts
/include/occa/core/kernelOperators.hpp_codegen
/src/core/kernelOperators.cpp_codegen
/src/occa/internal/utils/runFunction.cpp_codegen
/include/occa/defines/macros.hpp_codegen

# Binaries generated to fetch compiler information
/scripts/compiler/compilerSupportsMPI
Expand Down
14 changes: 14 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,20 @@ else()
add_compile_definitions(OCCA_THREAD_SHARABLE_ENABLED=0)
endif()

# Maximum number of kernel arguments OCCA is compiled to accept.
# Requests above MAX_NUM_KERNEL_ARGS_DEFAULT re-run the Python code generator
# (scripts/codegen/setup_kernel_operators.py) at configure time so the generated
# sources match the requested limit.
set(MAX_NUM_KERNEL_ARGS_DEFAULT "128")
set(MAX_NUM_KERNEL_ARGS "${MAX_NUM_KERNEL_ARGS_DEFAULT}" CACHE STRING "The maximum number of allowed kernel arguments")

# Reject non-numeric input early; otherwise it would silently end up in the
# OCCA_MAX_ARGS compile definition and fail much later in the C++ build.
if(NOT "${MAX_NUM_KERNEL_ARGS}" MATCHES "^[0-9]+$")
  message(FATAL_ERROR "MAX_NUM_KERNEL_ARGS must be a non-negative integer, got '${MAX_NUM_KERNEL_ARGS}'.")
endif()

if("${MAX_NUM_KERNEL_ARGS}" GREATER "${MAX_NUM_KERNEL_ARGS_DEFAULT}")
  # Locate a suitable interpreter instead of assuming a `python` executable is
  # on PATH (it may be absent, or may be a Python 2 that prints its version to stderr).
  find_package(Python3 3.7.2 QUIET COMPONENTS Interpreter)
  if(NOT Python3_Interpreter_FOUND)
    message(WARNING "Failed to set the maximum number of kernel arguments to ${MAX_NUM_KERNEL_ARGS}, required minimum python version 3.7.2. The default value ${MAX_NUM_KERNEL_ARGS_DEFAULT} will be used.")
    # Shadow the cache entry with a normal variable so OCCA_MAX_ARGS really is
    # the default, as the warning promises; the user's cache value is left untouched.
    set(MAX_NUM_KERNEL_ARGS "${MAX_NUM_KERNEL_ARGS_DEFAULT}")
  else()
    message(STATUS "Codegen for the maximum number of kernel arguments : ${MAX_NUM_KERNEL_ARGS}")
    execute_process(
      COMMAND "${CMAKE_COMMAND}" -E env "OCCA_DIR=${CMAKE_CURRENT_SOURCE_DIR}"
              "${Python3_EXECUTABLE}"
              "${CMAKE_CURRENT_SOURCE_DIR}/scripts/codegen/setup_kernel_operators.py"
              -N "${MAX_NUM_KERNEL_ARGS}"
      RESULT_VARIABLE occa_codegen_result
    )
    # A failed generator run may leave the generated sources out of sync with
    # OCCA_MAX_ARGS, so stop here rather than continue with a mismatched build.
    if(NOT occa_codegen_result EQUAL 0)
      message(FATAL_ERROR "Kernel-argument codegen failed (result: ${occa_codegen_result}) while generating sources for ${MAX_NUM_KERNEL_ARGS} arguments.")
    endif()
  endif()
endif()
add_compile_definitions(OCCA_MAX_ARGS=${MAX_NUM_KERNEL_ARGS})

set(OCCA_SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR})
set(OCCA_BUILD_DIR ${CMAKE_BINARY_DIR})

Expand Down
3 changes: 2 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,9 @@ MAKE_COMPILED_DEFINES := $(shell cat "$(OCCA_DIR)/scripts/build/compiledDefinesT
s,@@OCCA_HIP_ENABLED@@,$(OCCA_HIP_ENABLED),g;\
s,@@OCCA_OPENCL_ENABLED@@,$(OCCA_OPENCL_ENABLED),g;\
s,@@OCCA_METAL_ENABLED@@,$(OCCA_METAL_ENABLED),g;\
s,@@OCCA_DPCPP_ENABLED@@,$(OCCA_DPCPP_ENABLED),g;\
s,@@OCCA_DPCPP_ENABLED@@,$(OCCA_DPCPP_ENABLED),g;\
s,@@OCCA_THREAD_SHARABLE_ENABLED@@,$(OCCA_THREAD_SHARABLE_ENABLED),g;\
s,@@OCCA_MAX_ARGS@@,$(OCCA_MAX_ARGS),g;\
s,@@OCCA_BUILD_DIR@@,$(OCCA_BUILD_DIR),g;"\
> "$(NEW_COMPILED_DEFINES)")

Expand Down
33 changes: 1 addition & 32 deletions include/occa/defines/macros.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,38 +24,7 @@
// Just in case someone wants to run with an older format than C99
#ifndef OCCA_DISABLE_VARIADIC_MACROS

# define OCCA_ARG_COUNT(...) OCCA_ARG_COUNT2(\
__VA_ARGS__, \
128, 127, 126, 125, 124, 123, 122, 121, \
120, 119, 118, 117, 116, 115, 114, 113, 112, 111, \
110, 109, 108, 107, 106, 105, 104, 103, 102, 101, \
100, 99, 98, 97, 96, 95, 94, 93, 92, 91, \
90, 89, 88, 87, 86, 85, 84, 83, 82, 81, \
80, 79, 78, 77, 76, 75, 74, 73, 72, 71, \
70, 69, 68, 67, 66, 65, 64, 63, 62, 61, \
60, 59, 58, 57, 56, 55, 54, 53, 52, 51, \
50, 49, 48, 47, 46, 45, 44, 43, 42, 41, \
40, 39, 38, 37, 36, 35, 34, 33, 32, 31, \
30, 29, 28, 27, 26, 25, 24, 23, 22, 21, \
20, 19, 18, 17, 16, 15, 14, 13, 12, 11, \
10, 9, 8, 7, 6, 5, 4, 3, 2, 1, \
0)

# define OCCA_ARG_COUNT2( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
_11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
_21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
_31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
_41, _42, _43, _44, _45, _46, _47, _48, _49, _50, \
_51, _52, _53, _54, _55, _56, _57, _58, _59, _60, \
_61, _62, _63, _64, _65, _66, _67, _68, _69, _70, \
_71, _72, _73, _74, _75, _76, _77, _78, _79, _80, \
_81, _82, _83, _84, _85, _86, _87, _88, _89, _90, \
_91, _92, _93, _94, _95, _96, _97, _98, _99, _100, \
_101, _102, _103, _104, _105, _106, _107, _108, _109, _110, \
_111, _112, _113, _114, _115, _116, _117, _118, _119, _120, \
_121, _122, _123, _124, _125, _126, _127, _128, \
N, ...) N
#include "macros.hpp_codegen"

#endif // OCCA_DISABLE_VARIADIC_MACROS

Expand Down
38 changes: 38 additions & 0 deletions include/occa/defines/macros.hpp_codegen
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
// -------------[ DO NOT EDIT ]-------------
// THIS IS AN AUTOMATICALLY GENERATED FILE
// EDIT: scripts/codegen/setup_kernel_operators.py
// =========================================

# define OCCA_ARG_COUNT2( \
_1, _2, _3, _4, _5, _6, _7, _8, _9, _10, \
_11, _12, _13, _14, _15, _16, _17, _18, _19, _20, \
_21, _22, _23, _24, _25, _26, _27, _28, _29, _30, \
_31, _32, _33, _34, _35, _36, _37, _38, _39, _40, \
_41, _42, _43, _44, _45, _46, _47, _48, _49, _50, \
_51, _52, _53, _54, _55, _56, _57, _58, _59, _60, \
_61, _62, _63, _64, _65, _66, _67, _68, _69, _70, \
_71, _72, _73, _74, _75, _76, _77, _78, _79, _80, \
_81, _82, _83, _84, _85, _86, _87, _88, _89, _90, \
_91, _92, _93, _94, _95, _96, _97, _98, _99, _100, \
_101, _102, _103, _104, _105, _106, _107, _108, _109, _110, \
_111, _112, _113, _114, _115, _116, _117, _118, _119, _120, \
_121, _122, _123, _124, _125, _126, _127, _128, \
N, ...) N

# define OCCA_ARG_COUNT(...) OCCA_ARG_COUNT2( \
__VA_ARGS__, \
128, 127, 126, 125, 124, 123, 122, 121, \
120, 119, 118, 117, 116, 115, 114, 113, 112, 111, \
110, 109, 108, 107, 106, 105, 104, 103, 102, 101, \
100, 99, 98, 97, 96, 95, 94, 93, 92, 91, \
90, 89, 88, 87, 86, 85, 84, 83, 82, 81, \
80, 79, 78, 77, 76, 75, 74, 73, 72, 71, \
70, 69, 68, 67, 66, 65, 64, 63, 62, 61, \
60, 59, 58, 57, 56, 55, 54, 53, 52, 51, \
50, 49, 48, 47, 46, 45, 44, 43, 42, 41, \
40, 39, 38, 37, 36, 35, 34, 33, 32, 31, \
30, 29, 28, 27, 26, 25, 24, 23, 22, 21, \
20, 19, 18, 17, 16, 15, 14, 13, 12, 11, \
10, 9, 8, 7, 6, 5, 4, 3, 2, 1, \
0)

2 changes: 0 additions & 2 deletions include/occa/defines/occa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
#define OKL_VERSION 10600
#define OKL_VERSION_STR "1.6.0"

#define OCCA_MAX_ARGS 128

#define OCCA_DEFAULT_MEM_BYTE_ALIGN 32

#endif
44 changes: 28 additions & 16 deletions scripts/build/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -252,15 +252,16 @@ endif


#---[ Variable Dependencies ]---------------------
fortranEnabled = 0
mpiEnabled = 0
openmpEnabled = 0
cudaEnabled = 0
hipEnabled = 0
openclEnabled = 0
metalEnabled = 0
dpcppEnabled = 0

fortranEnabled = 0
mpiEnabled = 0
openmpEnabled = 0
cudaEnabled = 0
hipEnabled = 0
openclEnabled = 0
metalEnabled = 0
dpcppEnabled = 0
threadSharableEnabled = 0
maxArgs = 128

#---[ Fortran ]-------------------------
ifdef OCCA_FORTRAN_ENABLED
Expand Down Expand Up @@ -480,6 +481,15 @@ ifeq ($(usingMacOS),1)
endif
endif

#---[ Other parameters ]---------------------------
ifdef OCCA_THREAD_SHARABLE_ENABLED
threadSharableEnabled = $(OCCA_THREAD_SHARABLE_ENABLED)
endif

ifdef OCCA_MAX_ARGS
maxArgs = $(OCCA_MAX_ARGS)
endif

ifeq ($(cudaEnabled),1)
compilerFlags += -Wno-c++11-long-long
endif
Expand All @@ -491,11 +501,13 @@ else
OCCA_CHECK_ENABLED := 0
endif

OCCA_FORTRAN_ENABLED := $(fortranEnabled)
OCCA_OPENMP_ENABLED := $(openmpEnabled)
OCCA_CUDA_ENABLED := $(cudaEnabled)
OCCA_HIP_ENABLED := $(hipEnabled)
OCCA_OPENCL_ENABLED := $(openclEnabled)
OCCA_METAL_ENABLED := $(metalEnabled)
OCCA_DPCPP_ENABLED := $(dpcppEnabled)
OCCA_FORTRAN_ENABLED := $(fortranEnabled)
OCCA_OPENMP_ENABLED := $(openmpEnabled)
OCCA_CUDA_ENABLED := $(cudaEnabled)
OCCA_HIP_ENABLED := $(hipEnabled)
OCCA_OPENCL_ENABLED := $(openclEnabled)
OCCA_METAL_ENABLED := $(metalEnabled)
OCCA_DPCPP_ENABLED := $(dpcppEnabled)
OCCA_THREAD_SHARABLE_ENABLED := $(threadSharableEnabled)
OCCA_MAX_ARGS := $(maxArgs)
#=================================================
3 changes: 3 additions & 0 deletions scripts/build/compiledDefinesTemplate.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
#define OCCA_METAL_ENABLED @@OCCA_METAL_ENABLED@@
#define OCCA_DPCPP_ENABLED @@OCCA_DPCPP_ENABLED@@

#define OCCA_THREAD_SHARABLE_ENABLED @@OCCA_THREAD_SHARABLE_ENABLED@@
#define OCCA_MAX_ARGS @@OCCA_MAX_ARGS@@

#define OCCA_BUILD_DIR "@@OCCA_BUILD_DIR@@"

#endif
41 changes: 39 additions & 2 deletions scripts/codegen/setup_kernel_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import os
import functools

import argparse

OCCA_DIR = os.environ.get(
'OCCA_DIR',
Expand Down Expand Up @@ -79,7 +79,7 @@ def run_function_from_arguments(N):
content = '\nswitch (argc) {\n'
for n in range(N + 1):
content += run_function_from_argument(n)
content += '}\n';
content += ' default:\n OCCA_FORCE_ERROR("TOO MANY KERNEL ARGUMENTS REQUESTED");\n}\n'

return content

Expand Down Expand Up @@ -148,7 +148,44 @@ def operator_definition(N):
'''
return content

def macro_count2(N):
    """Return the text of the ``OCCA_ARG_COUNT2`` macro for ``N`` placeholders.

    The placeholders ``_1`` .. ``_N`` are laid out ten per row; every row is
    indented and ends with a line-continuation backslash. The final row holds
    the ``N, ...) N`` tail that closes the macro.
    """
    pad = ' ' * 2
    pieces = ['# define OCCA_ARG_COUNT2( \\\n']
    # Walk the placeholders one row (up to ten names) at a time.
    for row_start in range(1, N + 1, 10):
        row_stop = min(row_start + 9, N)
        names = ''.join('_' + str(k) + ', ' for k in range(row_start, row_stop + 1))
        pieces.append(pad + names + '\\\n')
    pieces.append(pad + 'N, ...) N\n')
    return ''.join(pieces)

def macro_count(N):
    """Return the text of the ``OCCA_ARG_COUNT`` macro counting down from ``N``.

    The macro forwards ``__VA_ARGS__`` followed by the literals ``N, N-1, ..., 1, 0``
    to ``OCCA_ARG_COUNT2``. A continuation line break is emitted after every
    value whose last digit is 1, so rows end at 121, 111, ..., 11, 1.
    """
    pad = ' ' * 2
    line_break = '\\\n' + pad
    out = [
        '# define OCCA_ARG_COUNT(...) OCCA_ARG_COUNT2( \\\n',
        pad,
        '__VA_ARGS__, ',
        line_break,
    ]
    for k in reversed(range(1, N + 1)):
        out.append(str(k) + ', ')
        if k % 10 == 1:
            out.append(line_break)
    out.append('0)\n')
    return ''.join(out)

@to_file('include/occa/defines/macros.hpp_codegen')
def macro_declarations(N):
    """Return both argument-counting macros for ``N`` arguments as one string.

    ``OCCA_ARG_COUNT2`` comes first, then a blank line, then ``OCCA_ARG_COUNT``.
    NOTE(review): the ``to_file`` decorator is defined outside this view;
    presumably it writes the returned text to the given path -- confirm.
    """
    count2_text = macro_count2(N)
    count_text = macro_count(N)
    return count2_text + '\n' + count_text

if __name__ == '__main__':
    # Command-line entry point: optionally override the argument limit with
    # -N/--NargsMax, then run every generator with that limit.
    cli = argparse.ArgumentParser(usage=__doc__)
    cli.add_argument("-N", "--NargsMax", type=int, default=MAX_ARGS)
    MAX_ARGS = cli.parse_args().NargsMax

    # Generators are invoked in the same order as before.
    for generate in (
        run_function_from_arguments,
        operator_declarations,
        operator_definitions,
        macro_declarations,
    ):
        generate(MAX_ARGS)
1 change: 1 addition & 0 deletions src/core/kernelOperators.cpp_codegen
Original file line number Diff line number Diff line change
Expand Up @@ -3490,3 +3490,4 @@ void kernel::operator() (const kernelArg &arg1, const kernelArg &arg2, const ker
modeKernel->setArguments(args, 128);
run();
}

2 changes: 1 addition & 1 deletion src/occa/internal/utils/runFunction.cpp_codegen
Original file line number Diff line number Diff line change
Expand Up @@ -1304,7 +1304,7 @@ switch (argc) {
args[95], args[96], args[97], args[98], args[99],
args[100], args[101]);
break;
case 103:
case 103:
f(args[0], args[1], args[2], args[3], args[4],
args[5], args[6], args[7], args[8], args[9],
args[10], args[11], args[12], args[13], args[14],
Expand Down