Skip to content

Commit

Permalink
[eudsl][llvmpy] extend bindings (AMDGCN) (#53)
Browse files Browse the repository at this point in the history
  • Loading branch information
makslevental authored Jan 26, 2025
1 parent dbb16f6 commit 10a2815
Show file tree
Hide file tree
Showing 14 changed files with 1,511 additions and 137 deletions.
38 changes: 25 additions & 13 deletions .github/workflows/build_test_release_eudsl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -130,22 +130,24 @@ jobs:
pip install cibuildwheel
- name: "Build eudsl-llvmpy"
if: ${{ ! startsWith(matrix.os, 'windows') }}
- name: "Build eudsl-tblgen"
run: |
if [[ "${{ matrix.os }}" == "ubuntu" ]]; then
export CCACHE_DIR=/host/$CCACHE_DIR
fi
$python3_command -m cibuildwheel "$PWD/projects/eudsl-llvmpy" --output-dir wheelhouse
$python3_command -m cibuildwheel "$PWD/projects/eudsl-tblgen" --output-dir wheelhouse
- name: "Build eudsl-tblgen"
- name: "Build eudsl-llvmpy"
if: ${{ ! startsWith(matrix.os, 'windows') }}
run: |
export PIP_FIND_LINKS=$PWD/wheelhouse
if [[ "${{ matrix.os }}" == "ubuntu" ]]; then
export CCACHE_DIR=/host/$CCACHE_DIR
export PIP_FIND_LINKS=/host/$PIP_FIND_LINKS
fi
$python3_command -m cibuildwheel "$PWD/projects/eudsl-tblgen" --output-dir wheelhouse
$python3_command -m cibuildwheel "$PWD/projects/eudsl-llvmpy" --output-dir wheelhouse
- name: "Build eudsl-nbgen"
run: |
Expand All @@ -159,20 +161,28 @@ jobs:
if: ${{ ! startsWith(matrix.os, 'windows') }}
run: |
if [[ "${{ matrix.os }}" == "ubuntu" ]]; then
export CCACHE_DIR=/host/$CCACHE_DIR
if [[ "${{ github.event_name }}" == "workflow_dispatch" ]]; then
if [[ "${{ matrix.os }}" == "ubuntu" ]]; then
export CCACHE_DIR=/host/$CCACHE_DIR
fi
$python3_command -m cibuildwheel "$PWD/projects/eudsl-py" --output-dir wheelhouse
else
export CMAKE_PREFIX_PATH=$PWD/llvm-install
export PIP_FIND_LINKS=$PWD/wheelhouse
$python3_command -m pip wheel "$PWD/projects/eudsl-py" -w wheelhouse -v
fi
$python3_command -m cibuildwheel "$PWD/projects/eudsl-py" --output-dir wheelhouse
# just to make sure the total build continues to work
- name: "Build all of eudsl"
run: |
pip install -r requirements.txt
$python3_command -m pip install -r requirements.txt
$python3_command -m pip install eudsl-tblgen -f wheelhouse
cmake -B $PWD/eudsl-build -S $PWD \
-DCMAKE_PREFIX_PATH=$PWD/llvm-install \
-DCMAKE_BUILD_TYPE=Release \
-DCMAKE_INSTALL_PREFIX=$PWD/eudsl-install
-DCMAKE_INSTALL_PREFIX=$PWD/eudsl-install \
-DPython3_EXECUTABLE=$(which $python3_command)
cmake --build "$PWD/eudsl-build" --target install
- name: "Save cache"
Expand Down Expand Up @@ -263,7 +273,9 @@ jobs:
"macos-14"
# "windows-2019"
]
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: [
# "3.9", "3.10", "3.11",
"3.12"]
include: [
{runs-on: "ubuntu-22.04", name: "ubuntu_x86_64", os: "ubuntu"},
# TODO(max): enable on windows by statically linking
Expand Down
5 changes: 4 additions & 1 deletion projects/eudsl-llvmpy/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,11 @@ include_directories(${EUDSLLLVM_SRC_DIR})

execute_process(
COMMAND "${Python_EXECUTABLE}" "${CMAKE_CURRENT_LIST_DIR}/eudsl-llvmpy-generate.py"
${LLVM_INCLUDE_DIRS}/llvm-c "${EUDSLLLVM_BINARY_DIR}/generated"
${LLVM_INCLUDE_DIRS}
"${EUDSLLLVM_BINARY_DIR}/generated"
"${EUDSLLLVM_SRC_DIR}/llvm"
RESULT_VARIABLE _has_err_generate
COMMAND_ECHO STDOUT
)
if (_has_err_generate AND NOT _has_err_generate EQUAL 0)
message(FATAL_ERROR "couldn't generate sources: ${_has_err_generate}")
Expand Down
272 changes: 263 additions & 9 deletions projects/eudsl-llvmpy/eudsl-llvmpy-generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from textwrap import dedent

import litgen
from eudsl_tblgen import RecordKeeper


def preprocess_code(code: str, here, header_f) -> str:
Expand Down Expand Up @@ -45,18 +46,20 @@ def replacement(s):
transformed_code = transformed_code.replace(
"typedef const void *LLVMErrorTypeId;", "typedef void *LLVMErrorTypeId;"
)
transformed_code = transformed_code.replace(
"extern const void*", "extern void*"
)
transformed_code = transformed_code.replace("extern const void*", "extern void*")
transformed_code = transformed_code.replace("/**", "/*")

pattern = "^LLVM_C_EXTERN_C_BEGIN"
replacement = 'extern "C" {'
transformed_code = re.sub(pattern, replacement, transformed_code, flags=re.MULTILINE)
transformed_code = re.sub(
pattern, replacement, transformed_code, flags=re.MULTILINE
)

pattern = "^LLVM_C_EXTERN_C_END"
replacement = "}"
transformed_code = re.sub(pattern, replacement, transformed_code, flags=re.MULTILINE)
transformed_code = re.sub(
pattern, replacement, transformed_code, flags=re.MULTILINE
)

return transformed_code

Expand Down Expand Up @@ -84,10 +87,258 @@ def generate_header_bindings(cpp_code):
return generated_code.pydef_code


def main(header_root, output_root):
__normalize_python_kws = {"class": "class_", "if": "if_", "else": "else_"}


def generate_amdgcn_intrinsics(llvm_include_root: Path, llvmpy_module_dir: Path):
    """Generate ``amdgcn.py`` with one Python wrapper per AMDGCN intrinsic.

    The output file is written to ``llvmpy_module_dir / "amdgcn.py"``. It
    starts with a fixed preamble of ``NewType`` aliases (used only as
    parameter annotations in the generated wrappers), followed by one function
    per record of ``llvm/IR/Intrinsics.td`` whose name starts with
    ``int_amdgcn`` and whose record type string is not ``ClangBuiltin``. Each
    generated function forwards its arguments to ``call_intrinsic`` using the
    dotted ``llvm.amdgcn.*`` intrinsic name.

    Args:
        llvm_include_root: LLVM include directory; ``llvm/IR/Intrinsics.td``
            is parsed from here and the directory itself is passed to the
            parser as the sole include dir.
        llvmpy_module_dir: Directory in which ``amdgcn.py`` is (re)created.

    Raises:
        NotImplementedError: If an anonymous parameter type is neither
            ``LLVMMatchType`` nor ``LLVMQualPointerType``.
    """
    int_regex = re.compile(r"_i(\d+)")
    fp_regex = re.compile(r"_f(\d+)")

    # The context manager guarantees the file is flushed and closed even when
    # TableGen parsing or code generation raises part-way through.
    with open(llvmpy_module_dir / "amdgcn.py", "w") as amdgcn_f:
        print(
            dedent(
                """\
                from typing import NewType, TypeVar, Generic
                from .util import call_intrinsic
                from . import ValueRef
                any = NewType("any", ValueRef)
                anyfloat = NewType("anyfloat", ValueRef)
                anyint = NewType("anyint", ValueRef)
                anyptr = NewType("anyptr", ValueRef)
                anyvector = NewType("anyvector", ValueRef)
                bfloat = NewType("bfloat", ValueRef)
                double = NewType("double", ValueRef)
                fp128 = NewType("fp128", ValueRef)
                fp80 = NewType("fp80", ValueRef)
                float = NewType("float", ValueRef)
                half = NewType("half", ValueRef)
                int1 = NewType("int1", ValueRef)
                int128 = NewType("int128", ValueRef)
                int16 = NewType("int16", ValueRef)
                int32 = NewType("int32", ValueRef)
                int64 = NewType("int64", ValueRef)
                int8 = NewType("int8", ValueRef)
                ppcfp128 = NewType("ppcfp128", ValueRef)
                pointer = NewType("pointer", ValueRef)
                void = NewType("void", ValueRef)
                v1i1 = NewType("v1i1", ValueRef)
                v2i1 = NewType("v2i1", ValueRef)
                v3i1 = NewType("v3i1", ValueRef)
                v4i1 = NewType("v4i1", ValueRef)
                v8i1 = NewType("v8i1", ValueRef)
                v16i1 = NewType("v16i1", ValueRef)
                v32i1 = NewType("v32i1", ValueRef)
                v64i1 = NewType("v64i1", ValueRef)
                v128i1 = NewType("v128i1", ValueRef)
                v256i1 = NewType("v256i1", ValueRef)
                v512i1 = NewType("v512i1", ValueRef)
                v1024i1 = NewType("v1024i1", ValueRef)
                v2048i1 = NewType("v2048i1", ValueRef)
                v128i2 = NewType("v128i2", ValueRef)
                v256i2 = NewType("v256i2", ValueRef)
                v64i4 = NewType("v64i4", ValueRef)
                v128i4 = NewType("v128i4", ValueRef)
                v1i8 = NewType("v1i8", ValueRef)
                v2i8 = NewType("v2i8", ValueRef)
                v3i8 = NewType("v3i8", ValueRef)
                v4i8 = NewType("v4i8", ValueRef)
                v8i8 = NewType("v8i8", ValueRef)
                v16i8 = NewType("v16i8", ValueRef)
                v32i8 = NewType("v32i8", ValueRef)
                v64i8 = NewType("v64i8", ValueRef)
                v128i8 = NewType("v128i8", ValueRef)
                v256i8 = NewType("v256i8", ValueRef)
                v512i8 = NewType("v512i8", ValueRef)
                v1024i8 = NewType("v1024i8", ValueRef)
                v1i16 = NewType("v1i16", ValueRef)
                v2i16 = NewType("v2i16", ValueRef)
                v3i16 = NewType("v3i16", ValueRef)
                v4i16 = NewType("v4i16", ValueRef)
                v8i16 = NewType("v8i16", ValueRef)
                v16i16 = NewType("v16i16", ValueRef)
                v32i16 = NewType("v32i16", ValueRef)
                v64i16 = NewType("v64i16", ValueRef)
                v128i16 = NewType("v128i16", ValueRef)
                v256i16 = NewType("v256i16", ValueRef)
                v512i16 = NewType("v512i16", ValueRef)
                v1i32 = NewType("v1i32", ValueRef)
                v2i32 = NewType("v2i32", ValueRef)
                v3i32 = NewType("v3i32", ValueRef)
                v4i32 = NewType("v4i32", ValueRef)
                v5i32 = NewType("v5i32", ValueRef)
                v6i32 = NewType("v6i32", ValueRef)
                v7i32 = NewType("v7i32", ValueRef)
                v8i32 = NewType("v8i32", ValueRef)
                v9i32 = NewType("v9i32", ValueRef)
                v10i32 = NewType("v10i32", ValueRef)
                v11i32 = NewType("v11i32", ValueRef)
                v12i32 = NewType("v12i32", ValueRef)
                v16i32 = NewType("v16i32", ValueRef)
                v32i32 = NewType("v32i32", ValueRef)
                v64i32 = NewType("v64i32", ValueRef)
                v128i32 = NewType("v128i32", ValueRef)
                v256i32 = NewType("v256i32", ValueRef)
                v512i32 = NewType("v512i32", ValueRef)
                v1024i32 = NewType("v1024i32", ValueRef)
                v2048i32 = NewType("v2048i32", ValueRef)
                v1i64 = NewType("v1i64", ValueRef)
                v2i64 = NewType("v2i64", ValueRef)
                v3i64 = NewType("v3i64", ValueRef)
                v4i64 = NewType("v4i64", ValueRef)
                v8i64 = NewType("v8i64", ValueRef)
                v16i64 = NewType("v16i64", ValueRef)
                v32i64 = NewType("v32i64", ValueRef)
                v64i64 = NewType("v64i64", ValueRef)
                v128i64 = NewType("v128i64", ValueRef)
                v256i64 = NewType("v256i64", ValueRef)
                v1i128 = NewType("v1i128", ValueRef)
                v1f16 = NewType("v1f16", ValueRef)
                v2f16 = NewType("v2f16", ValueRef)
                v3f16 = NewType("v3f16", ValueRef)
                v4f16 = NewType("v4f16", ValueRef)
                v8f16 = NewType("v8f16", ValueRef)
                v16f16 = NewType("v16f16", ValueRef)
                v32f16 = NewType("v32f16", ValueRef)
                v64f16 = NewType("v64f16", ValueRef)
                v128f16 = NewType("v128f16", ValueRef)
                v256f16 = NewType("v256f16", ValueRef)
                v512f16 = NewType("v512f16", ValueRef)
                v1bf16 = NewType("v1bf16", ValueRef)
                v2bf16 = NewType("v2bf16", ValueRef)
                v3bf16 = NewType("v3bf16", ValueRef)
                v4bf16 = NewType("v4bf16", ValueRef)
                v8bf16 = NewType("v8bf16", ValueRef)
                v16bf16 = NewType("v16bf16", ValueRef)
                v32bf16 = NewType("v32bf16", ValueRef)
                v64bf16 = NewType("v64bf16", ValueRef)
                v128bf16 = NewType("v128bf16", ValueRef)
                v1f32 = NewType("v1f32", ValueRef)
                v2f32 = NewType("v2f32", ValueRef)
                v3f32 = NewType("v3f32", ValueRef)
                v4f32 = NewType("v4f32", ValueRef)
                v5f32 = NewType("v5f32", ValueRef)
                v6f32 = NewType("v6f32", ValueRef)
                v7f32 = NewType("v7f32", ValueRef)
                v8f32 = NewType("v8f32", ValueRef)
                v9f32 = NewType("v9f32", ValueRef)
                v10f32 = NewType("v10f32", ValueRef)
                v11f32 = NewType("v11f32", ValueRef)
                v12f32 = NewType("v12f32", ValueRef)
                v16f32 = NewType("v16f32", ValueRef)
                v32f32 = NewType("v32f32", ValueRef)
                v64f32 = NewType("v64f32", ValueRef)
                v128f32 = NewType("v128f32", ValueRef)
                v256f32 = NewType("v256f32", ValueRef)
                v512f32 = NewType("v512f32", ValueRef)
                v1024f32 = NewType("v1024f32", ValueRef)
                v2048f32 = NewType("v2048f32", ValueRef)
                v1f64 = NewType("v1f64", ValueRef)
                v2f64 = NewType("v2f64", ValueRef)
                v3f64 = NewType("v3f64", ValueRef)
                v4f64 = NewType("v4f64", ValueRef)
                v8f64 = NewType("v8f64", ValueRef)
                v16f64 = NewType("v16f64", ValueRef)
                v32f64 = NewType("v32f64", ValueRef)
                v64f64 = NewType("v64f64", ValueRef)
                v128f64 = NewType("v128f64", ValueRef)
                v256f64 = NewType("v256f64", ValueRef)
                vararg = NewType("vararg", ValueRef)
                metadata = NewType("metadata", ValueRef)
                _T = TypeVar('_T')
                class LLVMQualPointerType(Generic[_T]):
                    pass
                local_ptr = LLVMQualPointerType[3]
                global_ptr = LLVMQualPointerType[1]
                AMDGPUBufferRsrcTy = LLVMQualPointerType[8];
                class LLVMMatchType(Generic[_T]):
                    pass
                """
            ),
            file=amdgcn_f,
        )

        intrins = RecordKeeper().parse_td(
            str(llvm_include_root / "llvm" / "IR" / "Intrinsics.td"),
            include_dirs=[str(llvm_include_root)],
        )

        for d in intrins.defs:
            intr = intrins.defs[d]
            # Only AMDGCN intrinsic records are wrapped; records whose type
            # string is "ClangBuiltin" are skipped.
            if not intr.name.startswith("int_amdgcn"):
                continue
            if intr.type.as_string == "ClangBuiltin":
                continue

            arg_types = []
            for p in intr.values.ParamTypes.value:
                p_s = p.as_string
                if p_s.startswith("anon"):
                    # Anonymous parameter record: fall back to the name of its
                    # type and append the subscript that identifies it.
                    p_s = p.type.as_string
                    if p_s == "LLVMMatchType":
                        p_s += f"[{p.def_.values.Number.value.value}]"
                    elif p_s == "LLVMQualPointerType":
                        _, addr_space = p.def_.values.Sig.value.values
                        p_s += f"[{addr_space}]"
                    else:
                        # `NotImplemented` is a non-callable constant; the
                        # exception class is `NotImplementedError`.
                        raise NotImplementedError(f"unsupported {p_s=}")
                else:
                    # e.g. "llvm_i32_ty" -> "llvm_int32_ty" -> "int32", which
                    # matches the alias names emitted in the preamble.
                    p_s = int_regex.sub(r"_int\1", p_s)
                    p_s = fp_regex.sub(r"_fp\1", p_s)
                    p_s = p_s.replace("llvm_", "").replace("_ty", "")

                if p_s == "ptr":
                    p_s = "pointer"

                arg_types.append(p_s)

            # Return types only decide whether the wrapper returns the call.
            ret_types = [p.as_string for p in intr.values.RetTypes.value]
            ret_str = "return " if ret_types else ""

            intr_name = d.replace("int_amdgcn_", "")
            llvm_intr_name = f"llvm.amdgcn.{intr_name.replace('_', '.')}"
            # Parameters are named a, b, c, ... in declaration order.
            # NOTE(review): an intrinsic with more than 26 parameters would
            # have the extras silently dropped by the slice/zip below.
            arg_names = "abcdefghijklmnopqrstuvwxyz"[: len(arg_types)]
            fn_args_str = ", ".join(
                [f"{n}: {t}" for n, t in zip(arg_names, arg_types)]
            )
            call_args_str = ", ".join(arg_names)
            if fn_args_str:
                fn_args_str = f"{fn_args_str}, "
                call_args_str = f"{call_args_str}, "

            # Rename wrappers whose short name collides with a Python keyword.
            intr_name = __normalize_python_kws.get(intr_name, intr_name)

            print(
                dedent(
                    f"""
                    def {intr_name}({fn_args_str}name=""):
                        {ret_str}call_intrinsic("{llvm_intr_name}", {call_args_str}name=name)
                    """
                ),
                file=amdgcn_f,
            )


def generate_nb_bindings(header_root: Path, output_root: Path):
pp_dir = output_root / "pp"
pp_dir.mkdir(parents=True, exist_ok=True)
for header_f in Path(header_root).rglob("*.h"):
for header_f in header_root.rglob("*.h"):
with open(header_f) as ff:
orig_code = ff.read()
pp_header_f = pp_dir / header_f.name
Expand Down Expand Up @@ -121,7 +372,10 @@ def main(header_root, output_root):

if __name__ == "__main__":
parser = argparse.ArgumentParser(prog="eudsl-llvmpy-generate")
parser.add_argument("headers_root", type=Path)
parser.add_argument("llvm_include_root", type=Path)
parser.add_argument("output_root", type=Path)
parser.add_argument("llvmpy_module_dir", type=Path)
args = parser.parse_args()
main(args.headers_root, args.output_root)

generate_nb_bindings(args.llvm_include_root / "llvm-c", args.output_root)
generate_amdgcn_intrinsics(args.llvm_include_root, args.llvmpy_module_dir)
Loading

0 comments on commit 10a2815

Please sign in to comment.