Proof-of-concept: A JAX callback that operates directly on device buffers. #23252

Open · wants to merge 1 commit into base: main
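In brief, the single commit registers an XLA FFI handler, xla_python_buffer_callback, that forwards the raw XLA_FFI_Buffer pointers to a Python callable rather than first materializing device buffers as host arrays. The callable receives an ExecutionContext followed by one Buffer per operand and one per result, i.e. a calling convention along these lines (callback name hypothetical; see FfiCallbackImpl in xla/python/ffi.cc below):

def callback(ctx, *buffers):  # (context, arg buffers..., result buffers...)
    ...

A fuller host-side sketch follows the ffi.pyi stub at the end.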
xla/ffi/ffi.h (+2, -0)
@@ -137,6 +137,8 @@ class AnyBuffer {
return se::DeviceMemoryBase(untyped_data(), size_bytes());
}

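// Escape hatch for bindings: exposes the underlying C-API buffer (used by the
// Python FFI callback added in xla/python/ffi.cc) without copying.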
const XLA_FFI_Buffer* buf() const { return buf_; }

private:
const XLA_FFI_Buffer* buf_;
};
xla/python/BUILD (+5, -0)
@@ -287,6 +287,7 @@ cc_library(
cc_library(
name = "py_client",
srcs = [
"ffi.cc",
"py_array.cc",
"py_client.cc",
"py_compile_only_client.cc",
@@ -300,6 +301,7 @@ cc_library(
"to_ifrt_sharding.cc",
],
hdrs = [
"ffi.h",
"py_array.h",
"py_client.h",
"py_compile_only_client.h",
@@ -370,6 +372,9 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/ffi/api:c_api",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/builder/lib:arithmetic",
xla/python/ffi.cc (+161, -0)
@@ -0,0 +1,161 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/python/ffi.h"

#include <cstddef>
#include <cstdint>
#include <vector>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
#include "xla/ffi/api/c_api.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/status_casters.h"
#include "xla/python/callback.h"
#include "xla/python/ifrt/host_callback.h"
#include "xla/python/py_host_callback.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"

namespace xla {

namespace nb = nanobind;

class PyContext {
public:
enum Stage {
kInstantiate = XLA_FFI_ExecutionStage_INSTANTIATE,
kPrepare = XLA_FFI_ExecutionStage_PREPARE,
kInitialize = XLA_FFI_ExecutionStage_INITIALIZE,
kExecute = XLA_FFI_ExecutionStage_EXECUTE,
};

PyContext(const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
XLA_FFI_ExecutionStage stage)
: api_(api), ctx_(ctx), stage_(stage) {}

Stage stage() const { return static_cast<Stage>(stage_); }
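// Fetches the stream handle for this execution via the C API (on GPU this is
// the backend's native stream); returns an error status if the runtime
// provides none.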
absl::StatusOr<void*> stream() const {
XLA_FFI_Stream_Get_Args args;
args.struct_size = XLA_FFI_Stream_Get_Args_STRUCT_SIZE;
args.extension_start = nullptr;
args.ctx = ctx_;
args.stream = nullptr;
if (XLA_FFI_Error* error = api_->XLA_FFI_Stream_Get(&args)) {
return ffi::TakeStatus(error);
}
return args.stream;
}

private:
const XLA_FFI_Api* api_;
XLA_FFI_ExecutionContext* ctx_;
XLA_FFI_ExecutionStage stage_;
};

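// Minimal Python-facing view of an FFI buffer; for this proof of concept only
// the raw data pointer is exposed (no dtype, shape, or size).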
class PyBuffer {
public:
explicit PyBuffer(const XLA_FFI_Buffer* buf) : buf_(buf) {}
void* data() const { return buf_->data; }

private:
const XLA_FFI_Buffer* buf_;
};

template <XLA_FFI_ExecutionStage stage>
absl::Status FfiCallbackImpl(
const XLA_FFI_Api* api, XLA_FFI_ExecutionContext* ctx,
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks,
uint64_t index, ffi::RemainingArgs args, ffi::RemainingRets rets) {
if (index >= callbacks->size()) {
return absl::InvalidArgumentError("Callback index out of range.");
}
auto loaded_callback = llvm::dyn_cast_or_null<PyCpuLoadedHostCallback>(
callbacks->at(index).get());
if (loaded_callback == nullptr) {
return absl::InternalError(
"Expected a PyCpuLoadedHostCallback, got something else.");
}
CpuCallback* callback = loaded_callback->cpu_callback();

nb::gil_scoped_acquire gil;
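// Build the argument tuple handed to Python:
// (context, arg buffers..., result buffers...).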
auto nb_args =
nb::steal<nb::tuple>(PyTuple_New(1 + args.size() + rets.size()));

PyContext py_ctx(api, ctx, stage);
PyTuple_SET_ITEM(nb_args.ptr(), 0, nb::cast(py_ctx).release().ptr());

size_t offset = 1;
for (size_t i = 0; i < args.size(); ++i, ++offset) {
TF_ASSIGN_OR_RETURN(auto arg, args.get<ffi::AnyBuffer>(i));
PyBuffer py_buffer(arg.buf());
PyTuple_SET_ITEM(nb_args.ptr(), offset,
nb::cast(py_buffer).release().ptr());
}

for (size_t i = 0; i < rets.size(); ++i, ++offset) {
TF_ASSIGN_OR_RETURN(auto ret, rets.get<ffi::AnyBuffer>(i));
PyBuffer py_buffer(ret->buf());
PyTuple_SET_ITEM(nb_args.ptr(), offset,
nb::cast(py_buffer).release().ptr());
}

EnterHostCallback();
absl::StatusOr<nb::tuple> maybe_result_tuple = callback->FfiCall(nb_args);
LeaveHostCallback();
TF_RETURN_IF_ERROR(maybe_result_tuple.status());

return absl::OkStatus();
}

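// Bind and register the execute-stage handler. The FFI api/context and the
// vector of loaded host callbacks arrive as context arguments; the "index"
// attribute selects which Python callback to invoke.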
XLA_FFI_DEFINE_HANDLER(
kFfiCallback, FfiCallbackImpl<XLA_FFI_ExecutionStage_EXECUTE>,
ffi::Ffi::Bind()
.Ctx<ffi::FfiApi>()
.Ctx<ffi::FfiExecutionContext>()
.Ctx<ffi::UserData<
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>>>()
.Attr<uint64_t>("index")
.RemainingArgs()
.RemainingRets());
XLA_FFI_REGISTER_HANDLER(ffi::GetXlaFfiApi(), "xla_python_buffer_callback",
"Host", kFfiCallback);

void BuildFfiSubmodule(nb::module_& m) {
nb::module_ ffi_module =
m.def_submodule("ffi", "Python bindings for the XLA FFI.");

nb::class_<PyBuffer> buffer(ffi_module, "Buffer");
buffer.def("data", &PyBuffer::data);

nb::enum_<PyContext::Stage>(ffi_module, "ExecutionStage")
.value("INSTANTIATE", PyContext::Stage::kInstantiate)
.value("PREPARE", PyContext::Stage::kPrepare)
.value("INITIALIZE", PyContext::Stage::kInitialize)
.value("EXECUTE", PyContext::Stage::kExecute)
.export_values();

nb::class_<PyContext> context(ffi_module, "ExecutionContext");
context.def("stage", &PyContext::stage);
context.def("stream", ValueOrThrowWrapper(&PyContext::stream));
}

} // namespace xla
xla/python/ffi.h (+27, -0)
@@ -0,0 +1,27 @@
/* Copyright 2025 The OpenXLA Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_PYTHON_FFI_H_
#define XLA_PYTHON_FFI_H_

#include "nanobind/nanobind.h"

namespace xla {

void BuildFfiSubmodule(nanobind::module_& m);

} // namespace xla

#endif // XLA_PYTHON_FFI_H_
xla/python/xla.cc (+2, -0)
@@ -95,6 +95,7 @@ limitations under the License.
#include "xla/python/config.h"
#include "xla/python/custom_call_sharding.h"
#include "xla/python/dlpack.h"
#include "xla/python/ffi.h"
#include "xla/python/guard_lib.h"
#include "xla/python/jax_jit.h"
#include "xla/python/logging.h" // IWYU pragma: keep
@@ -602,6 +603,7 @@ NB_MODULE(xla_extension, m) {
BuildMlirSubmodule(m);
BuildSdySubmodule(m);
BuildCustomCallShardingPybindAPI(m);
BuildFfiSubmodule(m);
#if defined(__linux__)
aux::RegisterTransferServerTypes(m);
#endif // defined(__linux__)
xla/python/xla_extension/__init__.pyi (+1, -0)
@@ -38,6 +38,7 @@ from typing import (
import numpy as np

from . import config
from . import ffi
from . import guard_lib
from . import ifrt_programs
from . import ifrt_proxy
xla/python/xla_extension/ffi.pyi (+33, -0)
@@ -0,0 +1,33 @@
# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

import ctypes
import enum


class Buffer:
def data(self) -> ctypes.c_void_p: ...


class ExecutionStage(enum.IntEnum):
INSTANTIATE: int
PREPARE: int
INITIALIZE: int
EXECUTE: int


class ExecutionContext:
def stage(self) -> ExecutionStage: ...
def stream(self) -> ctypes.c_void_p: ...
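
Since the handler above is registered for the "Host" platform, buffer memory is host-addressable and a callback can view it through ctypes. A minimal sketch under that assumption (the import path, callback name, element count, and dtype are all hypothetical, since the PoC exposes only the raw pointer):

import ctypes

from xla_extension import ffi  # packaging-dependent path, e.g. jaxlib.xla_extension.ffi

def add_one(ctx: ffi.ExecutionContext, x: ffi.Buffer, out: ffi.Buffer) -> None:
    # Execute-stage callback over two f32[4] buffers; shape and dtype must be
    # known out of band because Buffer only exposes the raw data pointer.
    assert ctx.stage() == ffi.ExecutionStage.EXECUTE
    src = ctypes.cast(x.data(), ctypes.POINTER(ctypes.c_float))
    dst = ctypes.cast(out.data(), ctypes.POINTER(ctypes.c_float))
    for i in range(4):
        dst[i] = src[i] + 1.0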