Implement Jax CPU/GPU callbacks with XLA's FFI. #23269

Open · wants to merge 1 commit into base: main
8 changes: 2 additions & 6 deletions xla/python/BUILD
@@ -522,8 +522,8 @@ cc_library(
":types",
"//xla:comparison_util",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/ffi/api:ffi",
"//xla/pjrt:exceptions",
"//xla/pjrt:host_callback",
"//xla/pjrt:transpose",
@@ -539,7 +539,6 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@nanobind",
"@tsl//tsl/platform:errors",
],
@@ -573,18 +572,16 @@ cc_library(
":types",
"//xla:comparison_util",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/ffi/api:ffi",
"//xla/pjrt:exceptions",
"//xla/pjrt:host_callback",
"//xla/pjrt:transpose",
"//xla/python/ifrt",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/service:platform_util",
"//xla/tsl/concurrency:ref_count",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector",
@@ -593,7 +590,6 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@nanobind",
] + if_rocm(
["@local_config_rocm//rocm:rocm_headers"],
2 changes: 2 additions & 0 deletions xla/python/ifrt/BUILD
@@ -86,6 +86,8 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi:ffi_api",
"//xla/ffi/api:ffi",
"//xla/hlo/ir:hlo",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_compiler",
7 changes: 7 additions & 0 deletions xla/python/ifrt/client.cc
@@ -15,10 +15,17 @@ limitations under the License.

#include "xla/python/ifrt/client.h"

#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/python/ifrt/host_callback.h"

namespace xla {
namespace ifrt {

char Client::ID = 0;
ffi::TypeId ifrt::FfiLoadedHostCallbacks::id = {};
XLA_FFI_REGISTER_TYPE(ffi::GetXlaFfiApi(), "ffi_loaded_host_callbacks",
&ifrt::FfiLoadedHostCallbacks::id);

} // namespace ifrt
} // namespace xla
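This change defines the static TypeId for FfiLoadedHostCallbacks and registers it with the FFI under the name "ffi_loaded_host_callbacks", so handlers built against the external FFI API (xla/ffi/api) can later resolve the same user data from the execution context. The sketch below mirrors that registration pattern for a hypothetical type; MyState and "my_state" are illustrative names, not part of this PR.

#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

namespace example {
// Hypothetical user-data type, following the FfiLoadedHostCallbacks pattern.
struct MyState {
  static xla::ffi::TypeId id;  // assigned by the registration below
  int value = 0;
};
xla::ffi::TypeId MyState::id = {};
// Binds the name "my_state" to MyState::id in the process-wide FFI type
// registry, just as client.cc does for "ffi_loaded_host_callbacks".
XLA_FFI_REGISTER_TYPE(xla::ffi::GetXlaFfiApi(), "my_state", &MyState::id);
}  // namespace example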
10 changes: 10 additions & 0 deletions xla/python/ifrt/host_callback.h
@@ -17,9 +17,11 @@ limitations under the License.
#define XLA_PYTHON_IFRT_HOST_CALLBACK_H_

#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/ffi/api/ffi.h"
#include "xla/tsl/concurrency/ref_count.h"

namespace xla {
@@ -69,6 +71,14 @@ class LoadedHostCallback
static char ID; // NOLINT
};

struct FfiLoadedHostCallbacks {
static xla::ffi::TypeId id;
explicit FfiLoadedHostCallbacks(
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks)
: callbacks(callbacks) {}
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks;
};

} // namespace ifrt
} // namespace xla

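FfiLoadedHostCallbacks is a thin, non-owning wrapper: it carries a raw pointer to the executable's vector of loaded host callbacks plus a static FFI TypeId (defined and registered in client.cc above). A minimal usage sketch, assuming a callback vector owned elsewhere:

// The wrapper does not own the vector; whoever creates it must keep the
// vector alive for as long as the wrapper can be reached through the FFI
// execution context (see pjrt_executable.cc below).
std::vector<tsl::RCReference<xla::ifrt::LoadedHostCallback>> loaded;  // owned by the caller
xla::ifrt::FfiLoadedHostCallbacks wrapper(&loaded);
// wrapper.callbacks->at(i) returns the i-th loaded callback.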
1 change: 1 addition & 0 deletions xla/python/pjrt_ifrt/BUILD
@@ -213,6 +213,7 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi:type_id_registry",
"//xla/hlo/ir:hlo",
"//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
"//xla/pjrt:host_callback",
11 changes: 9 additions & 2 deletions xla/python/pjrt_ifrt/pjrt_executable.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
#include "xla/pjrt/host_callback.h"
@@ -546,11 +547,16 @@ PjRtLoadedExecutable::Execute(absl::Span<tsl::RCReference<Array>> args,
opts.non_donatable_input_indices = options.non_donatable_input_indices;

auto context = std::make_shared<xla::ExecuteContext>();
auto callbacks = std::make_shared<FfiLoadedHostCallbacks>(
all_loaded_host_callbacks_.get());
auto platform_id = pjrt_loaded_executable_->client()->platform_id();
// Forward callbacks via FFI's ExecutionContext for CPU/GPU platforms only.
if (platform_id == CpuId() || platform_id == CudaId() ||
platform_id == RocmId() || platform_id == SyclId()) {
CHECK_OK(context->ffi_context().Insert(all_loaded_host_callbacks_.get()));
auto type_id =
xla::ffi::TypeIdRegistry::TypeId(FfiLoadedHostCallbacks::id.type_id);
CHECK_OK(context->ffi_context().Insert(
type_id, static_cast<void*>(callbacks.get())));
opts.context = context.get();
}

@@ -614,7 +620,8 @@ PjRtLoadedExecutable::Execute(absl::Span<tsl::RCReference<Array>> args,
// the execution finishes.
status.OnReady([all_loaded_host_callbacks = all_loaded_host_callbacks_,
host_callback_states = std::move(host_callback_states),
context = std::move(context)](absl::Status) mutable {
context = std::move(context),
callbacks = std::move(callbacks)](absl::Status) mutable {
all_loaded_host_callbacks.reset();
});
}
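On the producer side, the callbacks wrapper is inserted into the per-execution FFI ExecutionContext, keyed by the TypeId registered in client.cc, and the shared_ptr is captured by the OnReady closure so the raw pointer stored in the context stays valid until execution finishes. A condensed sketch of that flow, with variable names assumed rather than taken verbatim from the PR:

auto context = std::make_shared<xla::ExecuteContext>();
auto callbacks = std::make_shared<xla::ifrt::FfiLoadedHostCallbacks>(
    loaded_host_callbacks.get());
// Convert the FFI-registered TypeId into the registry's key type and stash
// the wrapper as opaque user data for handlers running on this execution.
auto type_id = xla::ffi::TypeIdRegistry::TypeId(
    xla::ifrt::FfiLoadedHostCallbacks::id.type_id);
CHECK_OK(context->ffi_context().Insert(type_id,
                                       static_cast<void*>(callbacks.get())));
opts.context = context.get();
// `callbacks` (and `context`) must outlive the execution; the PR achieves
// this by moving both into the status.OnReady callback.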
98 changes: 53 additions & 45 deletions xla/python/py_client_cpu.cc
@@ -29,9 +29,8 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/transpose.h"
@@ -42,39 +41,39 @@ limitations under the License.
#include "xla/python/py_host_callback.h"
#include "xla/python/types.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/statusor.h"

namespace nb = nanobind;

namespace xla {

absl::Status XlaFfiPythonCpuCallback(
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks,
uint64_t index, ffi::RemainingArgs args, ffi::RemainingRets rets) {
auto loaded_callback = llvm::dyn_cast_or_null<PyCpuLoadedHostCallback>(
callbacks->at(index).get());
if (loaded_callback == nullptr) {
return absl::InternalError(
"Expected a PyCpuLoadedHostCallback, got something else.");
}
ffi::Error XlaFfiPythonCpuCallback(ifrt::FfiLoadedHostCallbacks* callbacks,
uint64_t index, ffi::RemainingArgs args,
ffi::RemainingRets rets) {
auto loaded_callback = static_cast<PyCpuLoadedHostCallback*>(
callbacks->callbacks->at(index).get());
CpuCallback* callback = loaded_callback->cpu_callback();

nb::gil_scoped_acquire gil;
auto nb_args = nb::steal<nb::tuple>(PyTuple_New(args.size()));
for (size_t i = 0; i < args.size(); ++i) {
auto arg = args.get<ffi::AnyBuffer>(i);
auto ptype = arg->element_type();
auto ptype = static_cast<PrimitiveType>(arg->element_type());
if (ptype == TOKEN) {
PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr());
} else {
TF_ASSIGN_OR_RETURN(auto dtype, PrimitiveTypeToNbDtype(ptype));
// We pass in data using default numpy layout i.e., std::nullopt.
auto array = nb_numpy_ndarray(dtype, arg->dimensions(), std::nullopt,
arg.value().untyped_data());
array.attr("flags").attr("writeable") = nb::bool_(false);
PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr());
continue;
}
auto maybe_dtype = PrimitiveTypeToNbDtype(ptype);
if (!maybe_dtype.ok()) {
return ffi::Error::Internal(maybe_dtype.status().ToString());
}
auto dtype = maybe_dtype.value();
auto dims = absl::Span<const int64_t>(arg->dimensions().begin(),
arg->dimensions().size());
// We pass in data using default numpy layout i.e., std::nullopt.
auto array =
nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data());
array.attr("flags").attr("writeable") = nb::bool_(false);
PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr());
}

EnterHostCallback();
@@ -83,49 +82,58 @@ absl::Status XlaFfiPythonCpuCallback(
absl::StatusOr<nb::tuple> maybe_result_tuple =
callback->FfiCall(std::move(nb_args));
LeaveHostCallback();
TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple);
if (!maybe_result_tuple.ok()) {
return ffi::Error::Internal(maybe_result_tuple.status().ToString());
}
auto result_tuple = maybe_result_tuple.value();

for (size_t i = 0; i < rets.size(); ++i) {
auto arg = rets.get<ffi::AnyBuffer>(i).value();
auto ptype = arg->element_type();
auto ret = rets.get<ffi::AnyBuffer>(i).value();
auto ptype = static_cast<PrimitiveType>(ret->element_type());
if (ptype == TOKEN) continue;
nb::object output =
nb::borrow<nb::object>(PyTuple_GetItem(result_tuple.ptr(), i));
nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output));
absl::Span<int64_t const> strides(
reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
// We expect the output to be in default numpy layout.
TF_ASSIGN_OR_RETURN(auto expected_shape, ShapeUtil::MakeValidatedShape(
ptype, arg->dimensions()));
auto dims = absl::Span<const int64_t>(ret->dimensions().begin(),
ret->dimensions().size());
auto maybe_expected_shape = ShapeUtil::MakeValidatedShape(ptype, dims);
if (!maybe_expected_shape.ok()) {
return ffi::Error::Internal(maybe_expected_shape.status().ToString());
}
auto expected_shape = maybe_expected_shape.value();
auto expected_strides = ByteStridesForShape(expected_shape);
if (strides == expected_strides) {
std::memcpy(arg->untyped_data(), array.data(), arg->size_bytes());
} else {
xla::TransposePlan::Options options;
options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
absl::Span<int64_t const> dims(
reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
options.dims = dims;
absl::InlinedVector<int64_t, 4> reversed_layout;
reversed_layout.resize(expected_shape.dimensions_size());
absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
reversed_layout.begin());
options.permutation = reversed_layout;
options.input_layout = xla::TransposePlan::Striding{strides};
TF_ASSIGN_OR_RETURN(auto plan,
callback->transpose_cache().GetOrCreate(options));
plan->Execute(array.data(), arg->untyped_data());
std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes());
continue;
}
xla::TransposePlan::Options options;
options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
options.dims = absl::Span<const int64_t>(
reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
absl::InlinedVector<int64_t, 4> reversed_layout;
reversed_layout.resize(expected_shape.dimensions_size());
absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
reversed_layout.begin());
options.permutation = reversed_layout;
options.input_layout = xla::TransposePlan::Striding{strides};
auto maybe_plan = callback->transpose_cache().GetOrCreate(options);
if (!maybe_plan.ok()) {
return ffi::Error::Internal(maybe_plan.status().ToString());
}
auto plan = maybe_plan.value();
plan->Execute(array.data(), ret->untyped_data());
}

return absl::OkStatus();
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback,
ffi::Ffi::Bind()
.Ctx<ffi::UserData<
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>>>()
.Ctx<ffi::UserData<ifrt::FfiLoadedHostCallbacks>>()
.Attr<uint64_t>("index")
.RemainingArgs()
.RemainingRets());
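On the consumer side, the handler signature and the Ffi::Bind() spec must agree: the bound UserData context slot delivers the FfiLoadedHostCallbacks pointer that pjrt_executable.cc inserted, the "index" attribute selects one callback, and the remaining args/rets map to the custom call operands and results. A stripped-down sketch of that shape; MyCallback and kMyCallback are hypothetical names, not part of this PR.

static xla::ffi::Error MyCallback(xla::ifrt::FfiLoadedHostCallbacks* callbacks,
                                  uint64_t index,
                                  xla::ffi::RemainingArgs args,
                                  xla::ffi::RemainingRets rets) {
  // `callbacks` comes from the FFI execution context; `index` picks the
  // entry registered for this particular custom call.
  if (index >= callbacks->callbacks->size()) {
    return xla::ffi::Error::Internal("callback index out of range");
  }
  return xla::ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    kMyCallback, MyCallback,
    xla::ffi::Ffi::Bind()
        .Ctx<xla::ffi::UserData<xla::ifrt::FfiLoadedHostCallbacks>>()
        .Attr<uint64_t>("index")
        .RemainingArgs()
        .RemainingRets());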
2 changes: 1 addition & 1 deletion xla/python/py_client_cpu.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef XLA_PYTHON_PY_CLIENT_CPU_H_
#define XLA_PYTHON_PY_CLIENT_CPU_H_

#include "xla/ffi/ffi.h"
#include "xla/ffi/api/ffi.h"

namespace xla {
