Skip to content

Commit

Permalink
Add Deserialize function to HloRunnerInterface.
Browse files Browse the repository at this point in the history
This function consumes a runner-specific protobuf message, which it uses to construct a
runner-internal `OpaqueExecutable`. Where available, this function delegates to pre-existing
deserialization functionality, such as what exists in the PjRt API.

PiperOrigin-RevId: 731886192
  • Loading branch information
nvgrw authored and Google-ML-Automation committed Mar 1, 2025
1 parent 397b651 commit ea851b4
Show file tree
Hide file tree
Showing 7 changed files with 63 additions and 0 deletions.
4 changes: 4 additions & 0 deletions xla/service/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1591,6 +1591,7 @@ cc_library(
"//xla/pjrt/distributed:key_value_store_interface",
"//xla/stream_executor:dnn",
"//xla/stream_executor:stream_executor_h",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings",
Expand Down Expand Up @@ -4758,6 +4759,7 @@ cc_library(
"@com_google_absl//absl/status:statusor",
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:protobuf",
],
)

Expand Down Expand Up @@ -4804,6 +4806,7 @@ cc_library(
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@eigen_archive//:eigen3",
"@tsl//tsl/platform:protobuf",
],
)

Expand Down Expand Up @@ -4844,6 +4847,7 @@ cc_library(
"@com_google_absl//absl/strings:string_view",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/types:span",
"@tsl//tsl/platform:protobuf",
],
)

Expand Down
6 changes: 6 additions & 0 deletions xla/service/compiler.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ limitations under the License.
#include <utility>
#include <vector>

#include "absl/base/nullability.h"
#include "absl/container/flat_hash_map.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -322,6 +323,11 @@ class Compiler {
virtual std::unique_ptr<MetricsHookInterface> CreateMetricsHook(
absl::string_view filename_prefix) const;

virtual absl::StatusOr<std::unique_ptr<Executable>> DeserializeExecutable(
absl::Nonnull<const proto2::Message*> serialized) const {
return Unimplemented("DeserializeExecutable unimplemented");
}

private:
// Mutex that guards the platform-compiler map.
static absl::Mutex platform_compiler_mutex_;
Expand Down
9 changes: 9 additions & 0 deletions xla/service/hlo_runner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ limitations under the License.
#include "xla/tsl/platform/status.h"
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/threadpool.h"
#include "tsl/platform/protobuf.h"

namespace xla {

Expand Down Expand Up @@ -753,6 +754,14 @@ HloRunner::CreateExecutableWithBufferAssignment(
return std::make_unique<HloRunnerExecutable>(this, std::move(executable));
}

absl::StatusOr<std::unique_ptr<OpaqueExecutable>>
HloRunner::DeserializeExecutable(
absl::Nonnull<const proto2::Message*> serialized) const {
TF_ASSIGN_OR_RETURN(std::unique_ptr<Executable> executable,
backend().compiler()->DeserializeExecutable(serialized));
return std::make_unique<HloRunnerExecutable>(this, std::move(executable));
}

ServiceExecutableRunOptions HloRunner::GetServiceRunOptionsForDevice(
int64_t device, se::Stream* stream, DeviceAssignment* device_assignment,
RunId run_id, int local_device_count) {
Expand Down
8 changes: 8 additions & 0 deletions xla/service/hlo_runner.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ limitations under the License.
#include "xla/stream_executor/platform.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/protobuf.h"

namespace xla {

Expand Down Expand Up @@ -148,6 +149,13 @@ class HloRunner : public HloRunnerInterface {
const BufferAssignmentProto* /*buffer_assignment_proto*/,
bool run_hlo_passes) override;

// Creates a runner-internal executable object given a runner and
// platform-specific serialized executable representation. The serialized
// representation must have been produced by a compiler of the same platform
// and version as this one.
absl::StatusOr<std::unique_ptr<OpaqueExecutable>> DeserializeExecutable(
absl::Nonnull<const proto2::Message*> serialized) const override;

// Executes a given HLO module into a set of replicas, and returns a map
// with the replica number as key, and the corresponding returned literal as
// value.
Expand Down
9 changes: 9 additions & 0 deletions xla/service/hlo_runner_interface.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/util.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/protobuf.h"

namespace xla {

Expand Down Expand Up @@ -224,6 +225,14 @@ class HloRunnerInterface {
virtual absl::StatusOr<std::unique_ptr<OpaqueExecutable>> CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes) = 0;

// Creates a runner-internal executable object given a runner and
// platform-specific serialized executable representation. The serialized
// representation must have been produced by a compiler of the same platform
// and version as this one.
virtual absl::StatusOr<std::unique_ptr<OpaqueExecutable>>
DeserializeExecutable(
absl::Nonnull<const proto2::Message*> serialized) const = 0;

// Same as above, except it takes buffer assignment as input.
// Note: The default implementation of the API here does not utilize the given
// buffer assignment. A derived runner interface is expected to override the
Expand Down
19 changes: 19 additions & 0 deletions xla/service/hlo_runner_pjrt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <functional>
#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

Expand Down Expand Up @@ -54,6 +55,7 @@ limitations under the License.
#include "xla/tsl/platform/statusor.h"
#include "xla/tsl/platform/threadpool.h"
#include "xla/util.h"
#include "tsl/platform/protobuf.h"

namespace xla {

Expand Down Expand Up @@ -499,6 +501,23 @@ HloRunnerPjRt::CreateExecutable(std::unique_ptr<HloModule> module,
std::move(pjrt_executable));
}

absl::StatusOr<std::unique_ptr<OpaqueExecutable>>
HloRunnerPjRt::DeserializeExecutable(
absl::Nonnull<const proto2::Message*> serialized) const {
absl::Cord serialized_cord;
serialized->SerializeToString(&serialized_cord);

// TODO: b/237720161 - According to the comment in the base class, the
// `options` argument is mandatory. However, our implementation is capable of
// handling the default case where it is not present. The options are
// serialized with the executable and we can read them from there.
// Remove this comment once the bug is closed.
TF_ASSIGN_OR_RETURN(std::unique_ptr<PjRtLoadedExecutable> executable,
pjrt_client_->DeserializeExecutable(
serialized_cord.Flatten(), /*options=*/std::nullopt));
return std::make_unique<HloRunnerPjRtExecutable>(this, std::move(executable));
}

absl::StatusOr<std::vector<Literal>> HloRunnerPjRt::ExecuteReplicated(
std::unique_ptr<HloModule> module,
const HloRunnerInterface::ReplicatedExecuteOptions& options) {
Expand Down
8 changes: 8 additions & 0 deletions xla/service/hlo_runner_pjrt.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ limitations under the License.
#include "xla/service/hlo_runner_interface.h"
#include "xla/shape_layout.h"
#include "xla/xla_data.pb.h"
#include "tsl/platform/protobuf.h"

namespace xla {

Expand Down Expand Up @@ -96,6 +97,13 @@ class HloRunnerPjRt : public HloRunnerInterface {
absl::StatusOr<std::unique_ptr<OpaqueExecutable>> CreateExecutable(
std::unique_ptr<HloModule> module, bool run_hlo_passes) override;

// Creates a runner-internal executable object given a runner and
// platform-specific serialized executable representation. The serialized
// representation must have been produced by a compiler of the same platform
// and version as this one.
absl::StatusOr<std::unique_ptr<OpaqueExecutable>> DeserializeExecutable(
absl::Nonnull<const proto2::Message*> serialized) const override;

absl::StatusOr<Literal> ExecuteWithExecutable(
OpaqueExecutable* executable, absl::Span<const Literal* const> arguments,
ExecutionProfile* profile) override;
Expand Down

0 comments on commit ea851b4

Please sign in to comment.