[Computation Hash] Introduce deterministic hash for user computations
rpsilva-aws authored Jan 11, 2025
1 parent 28b9b0f commit 5ce8609
Showing 8 changed files with 211 additions and 17 deletions.
5 changes: 2 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -1166,9 +1166,8 @@ class PyLoweringContext {
   // Create a serialized HloModule protobuf from a lowered graph
   py::bytes GetHlo() {
     const xla::HloModuleProto& proto = computation.proto();
-    std::string result;
-    proto.SerializeToString(&result);
-    return result;
+    return ConsumeValue(
+        runtime::util::GetDeterministicSerializedModuleProto(proto));
   }

   // Create human-readable HloModule protobuf text from a lowered graph
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/BUILD
@@ -50,6 +50,7 @@ cc_library(
         ":types",
         ":util",
         ":xla_coordinator",
+        ":xla_util",
        "//torch_xla/csrc:device",
        "//torch_xla/csrc:dtype",
        "@com_google_absl//absl/memory",
@@ -460,6 +461,7 @@ ptxla_cc_test(
     size = "small",
     srcs = ["xla_util_test.cc"],
     deps = [
+        ":debug_macros",
        ":xla_util",
        "@com_google_absl//absl/types:span",
        "@com_google_googletest//:gtest_main",
9 changes: 9 additions & 0 deletions torch_xla/csrc/runtime/computation_client.cc
@@ -13,6 +13,7 @@
 #include "torch_xla/csrc/runtime/debug_macros.h"
 #include "torch_xla/csrc/runtime/env_vars.h"
 #include "torch_xla/csrc/runtime/sys_util.h"
+#include "torch_xla/csrc/runtime/xla_util.h"
 #include "tsl/platform/stacktrace_handler.h"
 #include "xla/status_macros.h"

@@ -194,5 +195,13 @@ metrics::Metric* ComputationClient::OutboundDataMetric() {
   return metric;
 }

+::absl::StatusOr<torch::lazy::hash_t>
+ComputationClient::Computation::ComputeHash(const xla::HloModuleProto& proto,
+                                            const std::string& name) {
+  TF_ASSIGN_OR_RETURN(auto serialized_status,
+                      util::GetDeterministicSerializedModuleProto(proto));
+  return torch::lazy::MHash(name, serialized_status);
+}
+
 }  // namespace runtime
 }  // namespace torch_xla
13 changes: 10 additions & 3 deletions torch_xla/csrc/runtime/computation_client.h
@@ -115,8 +115,8 @@ class ComputationClient {
           computation_(std::move(computation)),
           devices_(std::move(devices)) {
       program_shape_ = ConsumeValue(computation_.GetProgramShape());
-      hash_ =
-          torch::lazy::MHash(name, computation_.proto().SerializeAsString());
+      const xla::HloModuleProto& proto = computation_.proto();
+      hash_ = ConsumeValue(ComputeHash(proto, name));
     }

     Computation(std::string name, xla::XlaComputation computation,
@@ -159,7 +159,7 @@ class ComputationClient {
     // here.
     xla::XlaComputation move_computation() {
       if (computation_moved_) {
-        XLA_ERROR() << "Compuation has been moved\n";
+        XLA_ERROR() << "Computation has been moved\n";
       }
       computation_moved_ = true;
       return std::move(const_cast<Computation*>(this)->computation_);
@@ -206,6 +206,13 @@ class ComputationClient {

     torch::lazy::hash_t hash_;
     std::string name_;
+
+    // Computes a hash for an HLO module using deterministic proto
+    // serialization. It ensures consistent ordering of map fields and
+    // repeated elements during serialization. The resulting hash combines
+    // the serialized module with its computation name.
+    static ::absl::StatusOr<torch::lazy::hash_t> ComputeHash(
+        const xla::HloModuleProto& proto, const std::string& name);
   };

   using ComputationPtr = std::shared_ptr<Computation>;
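Editorial note on the change above: a hash built from SerializeAsString() is not stable across processes, because ordinary protobuf serialization makes no ordering guarantee for map fields such as frontend_attributes. A hypothetical sketch of the guarantee the new ComputeHash provides (MakeModule and the name "main" are illustrative inventions, not part of this commit):

    // Hypothetical sketch: two modules that differ only in the insertion
    // order of their frontend-attribute map entries now hash identically,
    // which keeps the compilation cache stable across runs.
    xla::HloModuleProto a = MakeModule({{"k1", "v1"}, {"k2", "v2"}});
    xla::HloModuleProto b = MakeModule({{"k2", "v2"}, {"k1", "v1"}});
    auto ha = ConsumeValue(ComputationClient::Computation::ComputeHash(a, "main"));
    auto hb = ConsumeValue(ComputationClient::Computation::ComputeHash(b, "main"));
    assert(ha == hb);  // deterministic serialization sorts map keys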
21 changes: 21 additions & 0 deletions torch_xla/csrc/runtime/xla_util.cc
@@ -14,6 +14,7 @@
 #include "tsl/platform/errors.h"
 #include "tsl/platform/stacktrace.h"
 #include "xla/shape_util.h"
+#include "xla/tsl/lib/strings/proto_serialization.h"
 #include "xla/util.h"

 namespace torch_xla {
@@ -115,6 +116,26 @@ torch::lazy::hash_t ShapeHash(const xla::Shape& shape) {
   return hash;
 }

+absl::StatusOr<std::string> GetDeterministicSerializedModuleProto(
+    const xla::HloModuleProto& hlo_proto) {
+  const size_t size = hlo_proto.ByteSizeLong();
+  if (size == 0) {
+    return std::string();
+  }
+  std::string serialized;
+  // Pre-allocate the string buffer for the serialized result.
+  serialized.resize(size);
+
+  // Perform deterministic serialization ensuring consistent ordering
+  // of map fields and repeated elements.
+  if (!tsl::SerializeToBufferDeterministic(hlo_proto, serialized.data(),
+                                           size)) {
+    return absl::InvalidArgumentError("Could not serialize module proto");
+  }
+
+  return serialized;
+}
+
 }  // namespace util
 }  // namespace runtime
 }  // namespace torch_xla
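For readers unfamiliar with tsl::SerializeToBufferDeterministic, here is a minimal sketch of what deterministic serialization amounts to with the plain protobuf API — a CodedOutputStream with SetSerializationDeterministic(true), which emits map entries in sorted key order so equal messages always produce identical bytes. This illustrates the mechanism under the assumption that the tsl helper is a convenience wrapper over this facility; it is not this codebase's code:

    #include <string>

    #include "google/protobuf/io/coded_stream.h"
    #include "google/protobuf/io/zero_copy_stream_impl_lite.h"
    #include "google/protobuf/message_lite.h"

    // Sketch: deterministic byte serialization for any protobuf message.
    std::string SerializeDeterministic(const google::protobuf::MessageLite& msg) {
      std::string out;
      {
        google::protobuf::io::StringOutputStream stream(&out);
        google::protobuf::io::CodedOutputStream coded(&stream);
        // Sorts map entries by key and pins other serialization choices.
        coded.SetSerializationDeterministic(true);
        msg.SerializeToCodedStream(&coded);
      }  // CodedOutputStream flushes to "out" on destruction.
      return out;
    }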
6 changes: 6 additions & 0 deletions torch_xla/csrc/runtime/xla_util.h
@@ -40,6 +40,12 @@ void CheckComputationStatus(

 torch::lazy::hash_t ShapeHash(const xla::Shape& shape);

+// Returns the serialized module proto, using deterministic proto
+// serialization, which ensures consistent ordering of map fields and
+// repeated elements.
+absl::StatusOr<std::string> GetDeterministicSerializedModuleProto(
+    const xla::HloModuleProto& hlo_proto);
+
 }  // namespace util
 }  // namespace runtime
 }  // namespace torch_xla
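A short usage sketch for the new helper, mirroring the two patterns at this commit's call sites (HashModule is a hypothetical name for illustration; TF_ASSIGN_OR_RETURN propagates failure, while ConsumeValue treats it as fatal):

    // In code that can propagate errors:
    absl::StatusOr<torch::lazy::hash_t> HashModule(
        const xla::HloModuleProto& proto, const std::string& name) {
      TF_ASSIGN_OR_RETURN(std::string serialized,
                          GetDeterministicSerializedModuleProto(proto));
      return torch::lazy::MHash(name, serialized);
    }

    // At call sites where failure is treated as fatal:
    std::string bytes =
        ConsumeValue(GetDeterministicSerializedModuleProto(proto));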
151 changes: 149 additions & 2 deletions torch_xla/csrc/runtime/xla_util_test.cc
@@ -3,12 +3,15 @@
 #include <gmock/gmock.h>
 #include <gtest/gtest.h>

+#include <random>
+#include <set>
 #include <string>
 #include <unordered_map>
 #include <vector>

 #include "absl/status/status.h"
 #include "absl/types/span.h"
+#include "torch_xla/csrc/runtime/debug_macros.h"
 #include "tsl/platform/errors.h"
 #include "tsl/platform/protobuf.h"
 #include "tsl/platform/status_matchers.h"
@@ -46,7 +49,7 @@ absl::StatusOr<MessageType> ParseTextProto(const std::string& text_proto) {
   return parsed_proto;
 }

-TEST(XlaUtilrest, CreateModule) {
+TEST(XlaUtilTest, CreateModule) {
   TF_ASSERT_OK_AND_ASSIGN(
       xla::HloModuleProto hlo_module_proto,
       ParseTextProto<xla::HloModuleProto>(
@@ -102,7 +105,7 @@ TEST(XlaUtilrest, CreateModule) {
   EXPECT_EQ((*got)->computation_count(), 1);
 }

-TEST(XlaUtilrest, XlaToHlo) {
+TEST(XlaUtilTest, XlaToHlo) {
   xla::Shape input_shape =
       xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
   xla::XlaBuilder builder("AddComputation");
@@ -116,6 +119,150 @@ TEST(XlaUtilrest, XlaToHlo) {
                        HasSubstr("ROOT %add.3"))));
 }

+TEST(XlaUtilTest, TestDeterministicModuleProtoSerializationEmptyProto) {
+  xla::HloModuleProto empty_proto;
+  auto result =
+      ::ConsumeValue(GetDeterministicSerializedModuleProto(empty_proto));
+  // Verify that the result is an empty string
+  EXPECT_TRUE(result.empty());
+}
+
+TEST(XlaUtilTest, TestDeterministicModuleProtoSerialization) {
+  // Create a test HLO module with a known structure
+  TF_ASSERT_OK_AND_ASSIGN(
+      xla::HloModuleProto hlo_module_proto,
+      ParseTextProto<xla::HloModuleProto>(
+          R"pb(
+            name: "myname"
+            id: 9
+            entry_computation_name: "MyCustomName.9"
+            entry_computation_id: 9
+            computations {
+              id: 9
+              name: "MyCustomName.9"
+              instructions: {
+                name: "p0.1"
+                id: 1
+                opcode: "parameter"
+                shape: {
+                  element_type: S64
+                  layout { tail_padding_alignment_in_elements: 1 }
+                }
+                metadata {
+                  op_type: "xla__device_data"
+                  op_name: "xla__device_data"
+                  source_file: "/ansible/pytorch/xla/small_test.py"
+                  source_line: 14
+                  stack_frame_id: 1
+                }
+              }
+              instructions: {
+                name: "p1.2"
+                id: 2
+                opcode: "parameter"
+                parameter_number: 1
+                shape: {
+                  element_type: S64
+                  layout { tail_padding_alignment_in_elements: 1 }
+                }
+                metadata {
+                  op_type: "xla__device_data"
+                  op_name: "xla__device_data"
+                  source_file: "/ansible/pytorch/xla/small_test.py"
+                  source_line: 13
+                  stack_frame_id: 2
+                }
+              }
+              instructions: {
+                name: "call.7"
+                id: 7
+                opcode: "call"
+                shape: {
+                  element_type: S64
+                  layout { tail_padding_alignment_in_elements: 1 }
+                }
+                metadata {
+                  op_type: "xla___op_some_op"
+                  op_name: "xla___op_some_op"
+                  source_file: "/ansible/pytorch/xla/torch_xla/core/xla_op_registry.py"
+                  source_line: 44
+                  stack_frame_id: 4
+                }
+                called_computation_ids: 3
+                operand_ids: 2
+                operand_ids: 1
+              }
+              instructions: {
+                name: "tuple.8"
+                id: 8
+                opcode: "tuple"
+                shape: {
+                  element_type: TUPLE
+                  tuple_shapes {
+                    element_type: S64
+                    layout { tail_padding_alignment_in_elements: 1 }
+                  }
+                }
+                operand_ids: 7
+              }
+              root_id: 8
+            }
+            host_program_shape: {
+              parameters {
+                element_type: S64
+                layout { tail_padding_alignment_in_elements: 1 }
+              }
+              parameters {
+                element_type: S64
+                layout { tail_padding_alignment_in_elements: 1 }
+              }
+              result {
+                element_type: TUPLE
+                tuple_shapes {
+                  element_type: S64
+                  layout { tail_padding_alignment_in_elements: 1 }
+                }
+              }
+              parameter_names: "p0"
+              parameter_names: "p1"
+            }
+          )pb"));
+
+  // Define a set of dummy fixed key-value pairs for frontend attributes.
+  std::vector<std::pair<std::string, std::string>> attr_pairs = {
+      {"key1", "value1"},
+      {"key2", "value2"},
+      {"key3", "value3"},
+      {"key4", "value4"}};
+
+  auto shuffle_and_hash = [&attr_pairs](xla::HloModuleProto hlo_module_proto) {
+    // Create a random number generator for shuffling.
+    std::random_device random_device;
+    std::mt19937 random_generator(random_device());
+
+    for (auto& computation : *hlo_module_proto.mutable_computations()) {
+      for (auto& instruction : *computation.mutable_instructions()) {
+        std::shuffle(attr_pairs.begin(), attr_pairs.end(), random_generator);
+        auto* frontend_attrs = instruction.mutable_frontend_attributes();
+        // Add the dummy shuffled pairs to the frontend attributes.
+        for (const auto& pair : attr_pairs) {
+          (*frontend_attrs->mutable_map())[pair.first] = pair.second;
+        }
+      }
+    }
+    std::string serialized_proto =
+        ::ConsumeValue(GetDeterministicSerializedModuleProto(hlo_module_proto));
+    return torch::lazy::Hash(serialized_proto);
+  };
+
+  // Compute hashes with different random orderings of attributes
+  torch::lazy::hash_t hash1 = shuffle_and_hash(hlo_module_proto);
+  torch::lazy::hash_t hash2 = shuffle_and_hash(hlo_module_proto);
+  // Verify that different orderings produce the same hash
+  ASSERT_EQ(hash1, hash2)
+      << "Hashes should match regardless of the frontend attribute ordering";
+}
+
 }  // namespace util
 }  // namespace runtime
 }  // namespace torch_xla
21 changes: 12 additions & 9 deletions torch_xla/csrc/xla_graph_executor.cpp
@@ -1200,12 +1200,13 @@ XLAGraphExecutor::LookupCachedCompile(const torch::lazy::hash_t& hash) {
     TORCH_LAZY_COUNTER("UncachedCompile", 1);
     return nullptr;
   }
+  std::string serialized_computation =
+      ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
+          cached_computation->computation->computation().proto()));
   TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(hash)
              << " is computation hash "
-             << torch::lazy::HashToString(torch::lazy::Hash(
-                    cached_computation->computation->computation()
-                        .proto()
-                        .SerializeAsString()));
+             << torch::lazy::HashToString(
+                    torch::lazy::Hash(serialized_computation));
   TORCH_LAZY_COUNTER("CachedCompile", 1);
   return cached_computation;
 }
@@ -1443,11 +1444,13 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
             << coll.device << " done!";
   TF_VLOG(5) << "Compiled program shape "
              << computations.front()->program_shape().ToString() << std::endl;
-  TF_VLOG(5)
-      << "Graph hash " << torch::lazy::HashToString(coll.hash)
-      << " is computation hash "
-      << torch::lazy::HashToString(torch::lazy::Hash(
-          computations.front()->computation().proto().SerializeAsString()));
+  std::string serialized_computation =
+      ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
+          computations.front()->computation().proto()));
+  TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(coll.hash)
+             << " is computation hash "
+             << torch::lazy::HashToString(
+                    torch::lazy::Hash(serialized_computation));

   if (use_autosharding) {
     const xla::HloModuleProto& computation_proto =
