Fix FFI callbacks for OSS.
PiperOrigin-RevId: 731931490
danielsuo authored and Google-ML-Automation committed Feb 28, 2025
1 parent 644bb8b commit e1c8822
Showing 10 changed files with 152 additions and 112 deletions.
10 changes: 4 additions & 6 deletions xla/python/BUILD
@@ -370,6 +370,8 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi:ffi_api",
"//xla/ffi/api:ffi",
"//xla/hlo/builder:xla_builder",
"//xla/hlo/builder:xla_computation",
"//xla/hlo/builder/lib:arithmetic",
@@ -522,8 +524,8 @@ cc_library(
":types",
"//xla:comparison_util",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/ffi/api:ffi",
"//xla/pjrt:exceptions",
"//xla/pjrt:host_callback",
"//xla/pjrt:transpose",
@@ -539,7 +541,6 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@nanobind",
"@tsl//tsl/platform:errors",
],
@@ -573,18 +574,16 @@ cc_library(
":types",
"//xla:comparison_util",
"//xla:shape_util",
"//xla/ffi",
"//xla/ffi:ffi_api",
"//xla/ffi/api:ffi",
"//xla/pjrt:exceptions",
"//xla/pjrt:host_callback",
"//xla/pjrt:transpose",
"//xla/python/ifrt",
"//xla/service:custom_call_status",
"//xla/service:custom_call_target_registry",
"//xla/service:platform_util",
"//xla/tsl/concurrency:ref_count",
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/container:inlined_vector",
@@ -593,7 +592,6 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"@llvm-project//llvm:Support",
"@nanobind",
] + if_rocm(
["@local_config_rocm//rocm:rocm_headers"],
1 change: 1 addition & 0 deletions xla/python/ifrt/BUILD
@@ -86,6 +86,7 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi/api:ffi",
"//xla/hlo/ir:hlo",
"//xla/pjrt:pjrt_client",
"//xla/pjrt:pjrt_compiler",
10 changes: 10 additions & 0 deletions xla/python/ifrt/host_callback.h
@@ -17,9 +17,11 @@ limitations under the License.
#define XLA_PYTHON_IFRT_HOST_CALLBACK_H_

#include <string>
#include <vector>

#include "absl/status/statusor.h"
#include "llvm/Support/ExtensibleRTTI.h"
#include "xla/ffi/api/ffi.h"
#include "xla/tsl/concurrency/ref_count.h"

namespace xla {
@@ -69,6 +71,14 @@ class LoadedHostCallback
static char ID; // NOLINT
};

struct FfiLoadedHostCallbacks {
static xla::ffi::TypeId id;
explicit FfiLoadedHostCallbacks(
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks)
: callbacks(callbacks) {}
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks;
};

} // namespace ifrt
} // namespace xla

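The new FfiLoadedHostCallbacks is a thin, non-owning wrapper: a bare std::vector<tsl::RCReference<LoadedHostCallback>>* has no FFI type id of its own, but a named struct with a static xla::ffi::TypeId can be keyed in the FFI execution context. A minimal sketch of constructing it (the vector is assumed to be populated elsewhere, e.g. when the executable is loaded):

#include <vector>

#include "xla/python/ifrt/host_callback.h"
#include "xla/tsl/concurrency/ref_count.h"

void Sketch(
    std::vector<tsl::RCReference<xla::ifrt::LoadedHostCallback>>* loaded) {
  // Non-owning: the wrapper stores only the pointer, so `loaded` must
  // outlive any execution that reads it through the FFI context.
  xla::ifrt::FfiLoadedHostCallbacks wrapper(loaded);
  // FfiLoadedHostCallbacks::id is filled in by XLA_FFI_REGISTER_TYPE in
  // py_client.cc (below) and is what keys &wrapper in the context.
}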
1 change: 1 addition & 0 deletions xla/python/pjrt_ifrt/BUILD
@@ -213,6 +213,7 @@ cc_library(
"//xla:status_macros",
"//xla:util",
"//xla:xla_data_proto_cc",
"//xla/ffi:type_id_registry",
"//xla/hlo/ir:hlo",
"//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
"//xla/pjrt:host_callback",
11 changes: 9 additions & 2 deletions xla/python/pjrt_ifrt/pjrt_executable.cc
@@ -31,6 +31,7 @@ limitations under the License.
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinOps.h"
#include "xla/ffi/type_id_registry.h"
#include "xla/hlo/ir/hlo_sharding.h"
#include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h"
#include "xla/pjrt/host_callback.h"
@@ -546,11 +547,16 @@ PjRtLoadedExecutable::Execute(absl::Span<tsl::RCReference<Array>> args,
opts.non_donatable_input_indices = options.non_donatable_input_indices;

auto context = std::make_shared<xla::ExecuteContext>();
auto callbacks = std::make_shared<FfiLoadedHostCallbacks>(
all_loaded_host_callbacks_.get());
auto platform_id = pjrt_loaded_executable_->client()->platform_id();
// Forward callbacks via FFI's ExecutionContext for CPU/GPU platforms only.
if (platform_id == CpuId() || platform_id == CudaId() ||
platform_id == RocmId() || platform_id == SyclId()) {
CHECK_OK(context->ffi_context().Insert(all_loaded_host_callbacks_.get()));
auto type_id =
xla::ffi::TypeIdRegistry::TypeId(FfiLoadedHostCallbacks::id.type_id);
CHECK_OK(context->ffi_context().Insert(
type_id, static_cast<void*>(callbacks.get())));
opts.context = context.get();
}

@@ -614,7 +620,8 @@ PjRtLoadedExecutable::Execute(absl::Span<tsl::RCReference<Array>> args,
// the execution finishes.
status.OnReady([all_loaded_host_callbacks = all_loaded_host_callbacks_,
host_callback_states = std::move(host_callback_states),
context = std::move(context)](absl::Status) mutable {
context = std::move(context),
callbacks = std::move(callbacks)](absl::Status) mutable {
all_loaded_host_callbacks.reset();
});
}
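One subtlety in the hunk above: ffi_context().Insert() records only a raw void*, so the FfiLoadedHostCallbacks wrapper needs an owner for the duration of the run. That is why callbacks is a shared_ptr and is moved into the OnReady closure alongside context. Condensed from the diff (platform check and error handling elided):

auto callbacks = std::make_shared<FfiLoadedHostCallbacks>(
    all_loaded_host_callbacks_.get());
auto type_id =
    xla::ffi::TypeIdRegistry::TypeId(FfiLoadedHostCallbacks::id.type_id);
CHECK_OK(context->ffi_context().Insert(
    type_id, static_cast<void*>(callbacks.get())));
// ... execution is launched ...
status.OnReady([context = std::move(context),
                callbacks = std::move(callbacks)](absl::Status) mutable {
  // Runs once execution has finished; releasing the captures here is what
  // keeps the inserted pointer valid for the whole run.
});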
6 changes: 6 additions & 0 deletions xla/python/py_client.cc
@@ -48,6 +48,8 @@ limitations under the License.
#include "nanobind/stl/unique_ptr.h" // IWYU pragma: keep
#include "nanobind/stl/variant.h" // IWYU pragma: keep
#include "nanobind/stl/vector.h" // IWYU pragma: keep
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/literal.h"
#include "xla/pjrt/exceptions.h"
#include "xla/pjrt/mlir_to_hlo.h"
@@ -666,6 +668,10 @@ absl::StatusOr<nb::object> PyClient::GetEmitPythonCallback(
XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
&XlaPythonCpuCallback);

ffi::TypeId ifrt::FfiLoadedHostCallbacks::id = {};
XLA_FFI_REGISTER_TYPE(ffi::GetXlaFfiApi(), "ffi_loaded_host_callbacks",
&ifrt::FfiLoadedHostCallbacks::id);

/* static */ int PyClient::tp_traverse(PyObject* self, visitproc visit,
void* arg) {
PyClient* c = nb::inst_ptr<PyClient>(self);
98 changes: 53 additions & 45 deletions xla/python/py_client_cpu.cc
@@ -29,9 +29,8 @@ limitations under the License.
#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/types/span.h"
#include "llvm/Support/Casting.h"
#include "nanobind/nanobind.h"
#include "xla/ffi/ffi.h"
#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"
#include "xla/pjrt/host_callback.h"
#include "xla/pjrt/transpose.h"
@@ -42,39 +41,39 @@ limitations under the License.
#include "xla/python/py_host_callback.h"
#include "xla/python/types.h"
#include "xla/shape_util.h"
#include "xla/tsl/concurrency/ref_count.h"
#include "xla/tsl/platform/statusor.h"

namespace nb = nanobind;

namespace xla {

absl::Status XlaFfiPythonCpuCallback(
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>* callbacks,
uint64_t index, ffi::RemainingArgs args, ffi::RemainingRets rets) {
auto loaded_callback = llvm::dyn_cast_or_null<PyCpuLoadedHostCallback>(
callbacks->at(index).get());
if (loaded_callback == nullptr) {
return absl::InternalError(
"Expected a PyCpuLoadedHostCallback, got something else.");
}
ffi::Error XlaFfiPythonCpuCallback(ifrt::FfiLoadedHostCallbacks* callbacks,
uint64_t index, ffi::RemainingArgs args,
ffi::RemainingRets rets) {
auto loaded_callback = static_cast<PyCpuLoadedHostCallback*>(
callbacks->callbacks->at(index).get());
CpuCallback* callback = loaded_callback->cpu_callback();

nb::gil_scoped_acquire gil;
auto nb_args = nb::steal<nb::tuple>(PyTuple_New(args.size()));
for (size_t i = 0; i < args.size(); ++i) {
auto arg = args.get<ffi::AnyBuffer>(i);
auto ptype = arg->element_type();
auto ptype = static_cast<PrimitiveType>(arg->element_type());
if (ptype == TOKEN) {
PyTuple_SET_ITEM(nb_args.ptr(), i, nb::none().release().ptr());
} else {
TF_ASSIGN_OR_RETURN(auto dtype, PrimitiveTypeToNbDtype(ptype));
// We pass in data using default numpy layout i.e., std::nullopt.
auto array = nb_numpy_ndarray(dtype, arg->dimensions(), std::nullopt,
arg.value().untyped_data());
array.attr("flags").attr("writeable") = nb::bool_(false);
PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr());
continue;
}
auto maybe_dtype = PrimitiveTypeToNbDtype(ptype);
if (!maybe_dtype.ok()) {
return ffi::Error::Internal(maybe_dtype.status().ToString());
}
auto dtype = maybe_dtype.value();
auto dims = absl::Span<const int64_t>(arg->dimensions().begin(),
arg->dimensions().size());
// We pass in data using default numpy layout i.e., std::nullopt.
auto array =
nb_numpy_ndarray(dtype, dims, std::nullopt, arg.value().untyped_data());
array.attr("flags").attr("writeable") = nb::bool_(false);
PyTuple_SET_ITEM(nb_args.ptr(), i, array.release().ptr());
}

EnterHostCallback();
@@ -83,49 +82,58 @@ absl::Status XlaFfiPythonCpuCallback(
absl::StatusOr<nb::tuple> maybe_result_tuple =
callback->FfiCall(std::move(nb_args));
LeaveHostCallback();
TF_ASSIGN_OR_RETURN(auto result_tuple, maybe_result_tuple);
if (!maybe_result_tuple.ok()) {
return ffi::Error::Internal(maybe_result_tuple.status().ToString());
}
auto result_tuple = maybe_result_tuple.value();

for (size_t i = 0; i < rets.size(); ++i) {
auto arg = rets.get<ffi::AnyBuffer>(i).value();
auto ptype = arg->element_type();
auto ret = rets.get<ffi::AnyBuffer>(i).value();
auto ptype = static_cast<PrimitiveType>(ret->element_type());
if (ptype == TOKEN) continue;
nb::object output =
nb::borrow<nb::object>(PyTuple_GetItem(result_tuple.ptr(), i));
nb_numpy_ndarray array = nb_numpy_ndarray::ensure(std::move(output));
absl::Span<int64_t const> strides(
reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
// We expect the output to be in default numpy layout.
TF_ASSIGN_OR_RETURN(auto expected_shape, ShapeUtil::MakeValidatedShape(
ptype, arg->dimensions()));
auto dims = absl::Span<const int64_t>(ret->dimensions().begin(),
ret->dimensions().size());
auto maybe_expected_shape = ShapeUtil::MakeValidatedShape(ptype, dims);
if (!maybe_expected_shape.ok()) {
return ffi::Error::Internal(maybe_expected_shape.status().ToString());
}
auto expected_shape = maybe_expected_shape.value();
auto expected_strides = ByteStridesForShape(expected_shape);
if (strides == expected_strides) {
std::memcpy(arg->untyped_data(), array.data(), arg->size_bytes());
} else {
xla::TransposePlan::Options options;
options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
absl::Span<int64_t const> dims(
reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
options.dims = dims;
absl::InlinedVector<int64_t, 4> reversed_layout;
reversed_layout.resize(expected_shape.dimensions_size());
absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
reversed_layout.begin());
options.permutation = reversed_layout;
options.input_layout = xla::TransposePlan::Striding{strides};
TF_ASSIGN_OR_RETURN(auto plan,
callback->transpose_cache().GetOrCreate(options));
plan->Execute(array.data(), arg->untyped_data());
std::memcpy(ret->untyped_data(), array.data(), ret->size_bytes());
continue;
}
xla::TransposePlan::Options options;
options.elem_size_in_bytes = xla::primitive_util::ByteWidth(ptype);
options.dims = absl::Span<const int64_t>(
reinterpret_cast<const int64_t*>(array.shape()), array.ndim());
absl::InlinedVector<int64_t, 4> reversed_layout;
reversed_layout.resize(expected_shape.dimensions_size());
absl::c_reverse_copy(expected_shape.layout().minor_to_major(),
reversed_layout.begin());
options.permutation = reversed_layout;
options.input_layout = xla::TransposePlan::Striding{strides};
auto maybe_plan = callback->transpose_cache().GetOrCreate(options);
if (!maybe_plan.ok()) {
return ffi::Error::Internal(maybe_plan.status().ToString());
}
auto plan = maybe_plan.value();
plan->Execute(array.data(), ret->untyped_data());
}

return absl::OkStatus();
return ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
kXlaFfiPythonCpuCallback, XlaFfiPythonCpuCallback,
ffi::Ffi::Bind()
.Ctx<ffi::UserData<
std::vector<tsl::RCReference<ifrt::LoadedHostCallback>>>>()
.Ctx<ffi::UserData<ifrt::FfiLoadedHostCallbacks>>()
.Attr<uint64_t>("index")
.RemainingArgs()
.RemainingRets());
2 changes: 1 addition & 1 deletion xla/python/py_client_cpu.h
@@ -16,7 +16,7 @@ limitations under the License.
#ifndef XLA_PYTHON_PY_CLIENT_CPU_H_
#define XLA_PYTHON_PY_CLIENT_CPU_H_

#include "xla/ffi/ffi.h"
#include "xla/ffi/api/ffi.h"

namespace xla {

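Taken together, the pattern this commit lands is: give the user-data struct a static TypeId, register it once with XLA_FFI_REGISTER_TYPE, insert a raw pointer under that id into the per-execution FFI context, and let the handler receive it back, typed, through Ctx<ffi::UserData<T>>. A self-contained sketch with a hypothetical payload type (MyState, "my_state", and kMyHandler are illustrative names, not part of this change):

#include "xla/ffi/api/ffi.h"
#include "xla/ffi/ffi_api.h"

namespace example {

struct MyState {               // hypothetical user-data type
  static xla::ffi::TypeId id;  // assigned by XLA_FFI_REGISTER_TYPE below
  int calls = 0;
};
xla::ffi::TypeId MyState::id = {};

// One-time registration, mirroring py_client.cc above.
XLA_FFI_REGISTER_TYPE(xla::ffi::GetXlaFfiApi(), "my_state", &MyState::id);

// Handler side: Ctx<UserData<MyState>> resolves the registered id from the
// execution context and hands the handler a typed MyState*.
static xla::ffi::Error MyHandler(MyState* state) {
  ++state->calls;
  return xla::ffi::Error::Success();
}

XLA_FFI_DEFINE_HANDLER_SYMBOL(
    kMyHandler, MyHandler,
    xla::ffi::Ffi::Bind().Ctx<xla::ffi::UserData<MyState>>());

}  // namespace example

On the producer side, an xla::ExecuteContext would insert &state under xla::ffi::TypeIdRegistry::TypeId(MyState::id.type_id), exactly as pjrt_executable.cc does above.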