From 678b90de9000707d3ba9f773287386074cf13f85 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 21 Jan 2025 16:25:38 -0500 Subject: [PATCH] feat: expose gpu memory allocation options (#589) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: expose more XLA GPU options to the user * feat: check for bazel install * fix: load env vars before * Update deps/build_local.jl Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com> * Update Project.toml --------- Co-authored-by: Mosè Giordano <765740+giordano@users.noreply.github.com> Co-authored-by: William Moses --- Project.toml | 2 +- deps/ReactantExtra/API.cpp | 4 ++++ deps/build_local.jl | 14 ++++++++++++-- src/XLA.jl | 19 ++++++++++++++++++- 4 files changed, 35 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index 46e072d40..7c973145f 100644 --- a/Project.toml +++ b/Project.toml @@ -70,7 +70,7 @@ PythonCall = "0.9" Random = "1.10" Random123 = "1.7" ReactantCore = "0.1.4" -Reactant_jll = "0.0.46" +Reactant_jll = "0.0.47" Scratch = "1.2" Sockets = "1.10" SpecialFunctions = "2.4" diff --git a/deps/ReactantExtra/API.cpp b/deps/ReactantExtra/API.cpp index 2c5819451..1a29765f7 100644 --- a/deps/ReactantExtra/API.cpp +++ b/deps/ReactantExtra/API.cpp @@ -275,11 +275,15 @@ extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id, extern "C" PjRtClient *MakeGPUClient(int node_id, int num_nodes, int *allowed_devices, int num_allowed_devices, + double memory_fraction, + bool preallocate, const char *platform_name, const char **error) { GpuClientOptions options; // options.kv_store = "etcd"; // options.allocator_config = + options.allocator_config.preallocate = preallocate; + options.allocator_config.memory_fraction = memory_fraction; options.node_id = node_id; options.num_nodes = num_nodes; options.allowed_devices = diff --git a/deps/build_local.jl b/deps/build_local.jl index 8a0c03e96..a5a9aa072 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -80,10 +80,20 @@ end @info "Building JLL with backend $(build_backend)" +bazel_cmd = if !isnothing(Sys.which("bazelisk")) + "bazelisk" +elseif !isnothing(Sys.which("bazel")) + "bazel" +else + error("Could not find `bazel` or `bazelisk` in PATH!") +end + +@info "Building JLL with $(bazel_cmd)" + if isempty(arg) run( Cmd( - `bazel build -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1]) + `$(bazel_cmd) build -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1]) --repo_env HERMETIC_PYTHON_VERSION="3.10" --check_visibility=false --verbose_failures :libReactantExtra.so`; dir=source_dir, @@ -92,7 +102,7 @@ if isempty(arg) else run( Cmd( - `bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1]) + `$(bazel_cmd) build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1]) --repo_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc --repo_env=CC=/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang --repo_env HERMETIC_PYTHON_VERSION="3.10" diff --git a/src/XLA.jl b/src/XLA.jl index 3f06c8d02..3aaaf87c1 100644 --- a/src/XLA.jl +++ b/src/XLA.jl @@ -2,6 +2,9 @@ module XLA import ...MLIR +const XLA_REACTANT_GPU_MEM_FRACTION = Ref{Float64}(0.75) +const XLA_REACTANT_GPU_PREALLOCATE = Ref{Bool}(true) + function LLVMclopts(opts...) args = ["", opts...] @ccall MLIR.API.mlir_c.ReactantLLVMParseCommandLineOptions( @@ -70,11 +73,13 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu") client = ccall( f, Ptr{Cvoid}, - (Cint, Cint, Ptr{Cvoid}, Cint, Cstring, Ptr{Cstring}), + (Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}), node_id, num_nodes, C_NULL, 0, + XLA_REACTANT_GPU_MEM_FRACTION[], + XLA_REACTANT_GPU_PREALLOCATE[], platform, refstr, ) @@ -124,6 +129,18 @@ function __init__() backends["cpu"] = cpu default_backend[] = cpu + if haskey(ENV, "XLA_REACTANT_GPU_MEM_FRACTION") + XLA_REACTANT_GPU_MEM_FRACTION[] = parse( + Float64, ENV["XLA_REACTANT_GPU_MEM_FRACTION"] + ) + @debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[] + end + + if haskey(ENV, "XLA_REACTANT_GPU_PREALLOCATE") + XLA_REACTANT_GPU_PREALLOCATE[] = parse(Bool, ENV["XLA_REACTANT_GPU_PREALLOCATE"]) + @debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[] + end + @static if !Sys.isapple() if isfile("/usr/lib/libtpu.so") dataset_dir = @get_scratch!("libtpu")