feat: expose gpu memory allocation options (#589)
* feat: expose more XLA GPU options to the user

* feat: check for bazel install

* fix: load env vars before

* Update deps/build_local.jl

Co-authored-by: Mosè Giordano <[email protected]>

* Update Project.toml

---------

Co-authored-by: Mosè Giordano <[email protected]>
Co-authored-by: William Moses <[email protected]>
3 people authored Jan 21, 2025
1 parent 66910bf commit 678b90d
Showing 4 changed files with 35 additions and 4 deletions.
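For context, a minimal usage sketch of the options this commit exposes. It assumes the package is loaded as Reactant; the two environment variables are read in XLA.__init__ (see the src/XLA.jl hunk below) and default to 0.75 and true when unset:

    # Cap the XLA GPU allocator at 60% of device memory and skip upfront
    # preallocation; set both variables before the package initializes.
    ENV["XLA_REACTANT_GPU_MEM_FRACTION"] = "0.6"
    ENV["XLA_REACTANT_GPU_PREALLOCATE"] = "false"

    using Reactant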
2 changes: 1 addition & 1 deletion Project.toml
@@ -70,7 +70,7 @@ PythonCall = "0.9"
 Random = "1.10"
 Random123 = "1.7"
 ReactantCore = "0.1.4"
-Reactant_jll = "0.0.46"
+Reactant_jll = "0.0.47"
 Scratch = "1.2"
 Sockets = "1.10"
 SpecialFunctions = "2.4"
4 changes: 4 additions & 0 deletions deps/ReactantExtra/API.cpp
@@ -275,11 +275,15 @@ extern "C" PjRtClient *MakeCPUClient(uint8_t asynchronous, int node_id,
 extern "C" PjRtClient *MakeGPUClient(int node_id, int num_nodes,
                                      int *allowed_devices,
                                      int num_allowed_devices,
+                                     double memory_fraction,
+                                     bool preallocate,
                                      const char *platform_name,
                                      const char **error) {
   GpuClientOptions options;
   // options.kv_store = "etcd";
   // options.allocator_config =
+  options.allocator_config.preallocate = preallocate;
+  options.allocator_config.memory_fraction = memory_fraction;
   options.node_id = node_id;
   options.num_nodes = num_nodes;
   options.allowed_devices =
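A hedged reading of the two new parameters, based on how XLA's GPU allocator options are generally described: memory_fraction bounds the share of each device's memory the client may use, and preallocate controls whether that share is reserved when the client is created rather than grown on demand. A toy Julia illustration of the budget this implies:

    # Toy arithmetic only: the allocator budget implied by the default fraction
    # on a hypothetical 24 GB device (the device size is an assumption, not from the diff).
    total_gb = 24.0
    memory_fraction = 0.75   # default exposed by this commit
    println("allocator budget ≈ ", total_gb * memory_fraction, " GB")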
14 changes: 12 additions & 2 deletions deps/build_local.jl
@@ -80,10 +80,20 @@ end

 @info "Building JLL with backend $(build_backend)"

+bazel_cmd = if !isnothing(Sys.which("bazelisk"))
+    "bazelisk"
+elseif !isnothing(Sys.which("bazel"))
+    "bazel"
+else
+    error("Could not find `bazel` or `bazelisk` in PATH!")
+end
+
+@info "Building JLL with $(bazel_cmd)"
+
 if isempty(arg)
     run(
         Cmd(
-            `bazel build -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
+            `$(bazel_cmd) build -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
             --repo_env HERMETIC_PYTHON_VERSION="3.10"
             --check_visibility=false --verbose_failures :libReactantExtra.so`;
             dir=source_dir,
@@ -92,7 +102,7 @@ if isempty(arg)
 else
     run(
         Cmd(
-            `bazel build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
+            `$(bazel_cmd) build $(arg) -c $(build_kind) --action_env=JULIA=$(Base.julia_cmd().exec[1])
             --repo_env=GCC_HOST_COMPILER_PATH=/usr/bin/gcc
             --repo_env=CC=/home/wmoses/llvms/llvm16-r/clang+llvm-16.0.2-x86_64-linux-gnu-ubuntu-22.04/bin/clang
             --repo_env HERMETIC_PYTHON_VERSION="3.10"
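If the new check errors out, the same lookup can be reproduced standalone to see which launcher, if any, the build script will pick up. A sketch mirroring the logic in the hunk above:

    # Standalone check mirroring deps/build_local.jl: bazelisk is preferred over bazel.
    launcher = if !isnothing(Sys.which("bazelisk"))
        "bazelisk"
    elseif !isnothing(Sys.which("bazel"))
        "bazel"
    else
        "neither bazelisk nor bazel on PATH"
    end
    println("build_local.jl would invoke: ", launcher)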
19 changes: 18 additions & 1 deletion src/XLA.jl
@@ -2,6 +2,9 @@ module XLA

 import ...MLIR

+const XLA_REACTANT_GPU_MEM_FRACTION = Ref{Float64}(0.75)
+const XLA_REACTANT_GPU_PREALLOCATE = Ref{Bool}(true)
+
 function LLVMclopts(opts...)
     args = ["", opts...]
     @ccall MLIR.API.mlir_c.ReactantLLVMParseCommandLineOptions(
@@ -70,11 +73,13 @@ function GPUClient(node_id=0, num_nodes=1, platform="gpu")
     client = ccall(
         f,
         Ptr{Cvoid},
-        (Cint, Cint, Ptr{Cvoid}, Cint, Cstring, Ptr{Cstring}),
+        (Cint, Cint, Ptr{Cvoid}, Cint, Cdouble, Bool, Cstring, Ptr{Cstring}),
         node_id,
         num_nodes,
         C_NULL,
         0,
+        XLA_REACTANT_GPU_MEM_FRACTION[],
+        XLA_REACTANT_GPU_PREALLOCATE[],
         platform,
         refstr,
     )
@@ -124,6 +129,18 @@ function __init__()
     backends["cpu"] = cpu
     default_backend[] = cpu

+    if haskey(ENV, "XLA_REACTANT_GPU_MEM_FRACTION")
+        XLA_REACTANT_GPU_MEM_FRACTION[] = parse(
+            Float64, ENV["XLA_REACTANT_GPU_MEM_FRACTION"]
+        )
+        @debug "XLA_REACTANT_GPU_MEM_FRACTION: " XLA_REACTANT_GPU_MEM_FRACTION[]
+    end
+
+    if haskey(ENV, "XLA_REACTANT_GPU_PREALLOCATE")
+        XLA_REACTANT_GPU_PREALLOCATE[] = parse(Bool, ENV["XLA_REACTANT_GPU_PREALLOCATE"])
+        @debug "XLA_REACTANT_GPU_PREALLOCATE: " XLA_REACTANT_GPU_PREALLOCATE[]
+    end
+
     @static if !Sys.isapple()
         if isfile("/usr/lib/libtpu.so")
             dataset_dir = @get_scratch!("libtpu")
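Because the defaults live in Refs inside the XLA module, they can also be adjusted from Julia before a GPU client is constructed, as an alternative to the environment variables shown earlier. A sketch only; the module path Reactant.XLA and the timing relative to client creation are assumptions, not guaranteed by this diff:

    using Reactant

    # Assumed usage: lower the allocator budget and disable preallocation before
    # any call that creates a GPU client via XLA.GPUClient reads these Refs.
    Reactant.XLA.XLA_REACTANT_GPU_MEM_FRACTION[] = 0.5
    Reactant.XLA.XLA_REACTANT_GPU_PREALLOCATE[] = false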
