Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Computation Hash] Introduce deterministic hash for user computations #8554

Merged
merged 2 commits into from
Jan 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1166,9 +1166,8 @@ class PyLoweringContext {
// Create a serialized HloModule protobuf from a lowered graph
py::bytes GetHlo() {
const xla::HloModuleProto& proto = computation.proto();
std::string result;
proto.SerializeToString(&result);
return result;
return ConsumeValue(
runtime::util::GetDeterministicSerializedModuleProto(proto));
}

// Create human-readable HloModule protobuf text from a lowered graph
Expand Down
2 changes: 2 additions & 0 deletions torch_xla/csrc/runtime/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ cc_library(
":types",
":util",
":xla_coordinator",
":xla_util",
"//torch_xla/csrc:device",
"//torch_xla/csrc:dtype",
"@com_google_absl//absl/memory",
Expand Down Expand Up @@ -460,6 +461,7 @@ ptxla_cc_test(
size = "small",
srcs = ["xla_util_test.cc"],
deps = [
":debug_macros",
":xla_util",
"@com_google_absl//absl/types:span",
"@com_google_googletest//:gtest_main",
Expand Down
9 changes: 9 additions & 0 deletions torch_xla/csrc/runtime/computation_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "torch_xla/csrc/runtime/env_vars.h"
#include "torch_xla/csrc/runtime/sys_util.h"
#include "torch_xla/csrc/runtime/xla_util.h"
#include "tsl/platform/stacktrace_handler.h"
#include "xla/status_macros.h"

Expand Down Expand Up @@ -194,5 +195,13 @@ metrics::Metric* ComputationClient::OutboundDataMetric() {
return metric;
}

// Hashes `proto` together with `name`. The module is first serialized with
// util::GetDeterministicSerializedModuleProto; if that fails, its error status
// is returned to the caller instead of a hash.
::absl::StatusOr<torch::lazy::hash_t>
ComputationClient::Computation::ComputeHash(const xla::HloModuleProto& proto,
                                            const std::string& name) {
  // On success this holds the serialized bytes (not a status object).
  TF_ASSIGN_OR_RETURN(const std::string serialized_module,
                      util::GetDeterministicSerializedModuleProto(proto));
  return torch::lazy::MHash(name, serialized_module);
}

} // namespace runtime
} // namespace torch_xla
13 changes: 10 additions & 3 deletions torch_xla/csrc/runtime/computation_client.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,8 +115,8 @@ class ComputationClient {
computation_(std::move(computation)),
devices_(std::move(devices)) {
program_shape_ = ConsumeValue(computation_.GetProgramShape());
hash_ =
torch::lazy::MHash(name, computation_.proto().SerializeAsString());
const xla::HloModuleProto& proto = computation_.proto();
hash_ = ConsumeValue(ComputeHash(proto, name));
}

Computation(std::string name, xla::XlaComputation computation,
Expand Down Expand Up @@ -159,7 +159,7 @@ class ComputationClient {
// here.
xla::XlaComputation move_computation() {
if (computation_moved_) {
XLA_ERROR() << "Compuation has been moved\n";
XLA_ERROR() << "Computation has been moved\n";
}
computation_moved_ = true;
return std::move(const_cast<Computation*>(this)->computation_);
Expand Down Expand Up @@ -206,6 +206,13 @@ class ComputationClient {

torch::lazy::hash_t hash_;
std::string name_;

// Computes a hash for an HLO module using deterministic proto
// serialization. It ensures consistent ordering of Map fields and repeated
// elements during serialization. The resulting hash combines the
// serialized module with its computation name.
// Returns a non-OK status if the deterministic serialization fails.
static ::absl::StatusOr<torch::lazy::hash_t> ComputeHash(
const xla::HloModuleProto& proto, const std::string& name);
};

using ComputationPtr = std::shared_ptr<Computation>;
Expand Down
21 changes: 21 additions & 0 deletions torch_xla/csrc/runtime/xla_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "tsl/platform/errors.h"
#include "tsl/platform/stacktrace.h"
#include "xla/shape_util.h"
#include "xla/tsl/lib/strings/proto_serialization.h"
#include "xla/util.h"

namespace torch_xla {
Expand Down Expand Up @@ -115,6 +116,26 @@ torch::lazy::hash_t ShapeHash(const xla::Shape& shape) {
return hash;
}

// Serializes `hlo_proto` through tsl::SerializeToBufferDeterministic and
// returns the resulting bytes. A proto whose byte size is zero produces an
// empty string without invoking the serializer. A serializer failure is
// reported as an InvalidArgument error.
absl::StatusOr<std::string> GetDeterministicSerializedModuleProto(
    const xla::HloModuleProto& hlo_proto) {
  const size_t byte_size = hlo_proto.ByteSizeLong();

  // Size the output buffer up front so the serializer can write straight into
  // it; for a zero-sized proto this is simply an empty string.
  std::string buffer(byte_size, '\0');

  // Deterministic serialization ensures consistent ordering of map fields and
  // repeated elements.
  const bool needs_serialization = byte_size != 0;
  if (needs_serialization &&
      !tsl::SerializeToBufferDeterministic(hlo_proto, buffer.data(),
                                           byte_size)) {
    return absl::InvalidArgumentError("Could not serialize module proto");
  }

  return buffer;
}

} // namespace util
} // namespace runtime
} // namespace torch_xla
6 changes: 6 additions & 0 deletions torch_xla/csrc/runtime/xla_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ void CheckComputationStatus(

torch::lazy::hash_t ShapeHash(const xla::Shape& shape);

// Return the serialized module proto, using deterministic proto serialization.
// It ensures consistent ordering of Map fields and repeated elements during
// serialization.
// A proto with zero byte size yields an empty string; a serialization failure
// is reported as an InvalidArgument error.
absl::StatusOr<std::string> GetDeterministicSerializedModuleProto(
const xla::HloModuleProto& hlo_proto);

} // namespace util
} // namespace runtime
} // namespace torch_xla
Expand Down
151 changes: 149 additions & 2 deletions torch_xla/csrc/runtime/xla_util_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,15 @@
#include <gmock/gmock.h>
#include <gtest/gtest.h>

#include <random>
#include <set>
#include <string>
#include <unordered_map>
#include <vector>

#include "absl/status/status.h"
#include "absl/types/span.h"
#include "torch_xla/csrc/runtime/debug_macros.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/protobuf.h"
#include "tsl/platform/status_matchers.h"
Expand Down Expand Up @@ -46,7 +49,7 @@ absl::StatusOr<MessageType> ParseTextProto(const std::string& text_proto) {
return parsed_proto;
}

TEST(XlaUtilrest, CreateModule) {
TEST(XlaUtilTest, CreateModule) {
TF_ASSERT_OK_AND_ASSIGN(
xla::HloModuleProto hlo_module_proto,
ParseTextProto<xla::HloModuleProto>(
Expand Down Expand Up @@ -102,7 +105,7 @@ TEST(XlaUtilrest, CreateModule) {
EXPECT_EQ((*got)->computation_count(), 1);
}

TEST(XlaUtilrest, XlaToHlo) {
TEST(XlaUtilTest, XlaToHlo) {
xla::Shape input_shape =
xla::ShapeUtil::MakeShape(xla::PrimitiveType::F32, {2, 2});
xla::XlaBuilder builder("AddComputation");
Expand All @@ -116,6 +119,150 @@ TEST(XlaUtilrest, XlaToHlo) {
HasSubstr("ROOT %add.3"))));
}

// Serializing a default-constructed HloModuleProto must succeed and produce
// no bytes.
TEST(XlaUtilTest, TestDeterministicModuleProtoSerializationEmptyProto) {
  xla::HloModuleProto empty_proto;
  // ConsumeValue unwraps the StatusOr returned by the serializer.
  const std::string serialized =
      ::ConsumeValue(GetDeterministicSerializedModuleProto(empty_proto));
  // The serialized form of the empty proto is expected to be an empty string.
  EXPECT_TRUE(serialized.empty());
}

// Checks that the hash of the deterministically serialized module does not
// depend on the order in which frontend attributes were inserted into each
// instruction's attribute map.
TEST(XlaUtilTest, TestDeterministicModuleProtoSerialization) {
// Create a test HLO module with a known structure
TF_ASSERT_OK_AND_ASSIGN(
xla::HloModuleProto hlo_module_proto,
ParseTextProto<xla::HloModuleProto>(
R"pb(
name: "myname"
id: 9
entry_computation_name: "MyCustomName.9"
entry_computation_id: 9
computations {
id: 9
name: "MyCustomName.9"
instructions: {
name: "p0.1"
id: 1
opcode: "parameter"
shape: {
element_type: S64
layout { tail_padding_alignment_in_elements: 1 }
}
metadata {
op_type: "xla__device_data"
op_name: "xla__device_data"
source_file: "/ansible/pytorch/xla/small_test.py"
source_line: 14
stack_frame_id: 1
}
}
instructions: {
name: "p1.2"
id: 2
opcode: "parameter"
parameter_number: 1
shape: {
element_type: S64
layout { tail_padding_alignment_in_elements: 1 }
}
metadata {
op_type: "xla__device_data"
op_name: "xla__device_data"
source_file: "/ansible/pytorch/xla/small_test.py"
source_line: 13
stack_frame_id: 2
}
}
instructions: {
name: "call.7"
id: 7
opcode: "call"
shape: {
element_type: S64
layout { tail_padding_alignment_in_elements: 1 }
}
metadata {
op_type: "xla___op_some_op"
op_name: "xla___op_some_op"
source_file: "/ansible/pytorch/xla/torch_xla/core/xla_op_registry.py"
source_line: 44
stack_frame_id: 4
}
called_computation_ids: 3
operand_ids: 2
operand_ids: 1
}
instructions: {
name: "tuple.8"
id: 8
opcode: "tuple"
shape: {
element_type: TUPLE
tuple_shapes {
element_type: S64
layout { tail_padding_alignment_in_elements: 1 }
}
}
operand_ids: 7
}
root_id: 8
}
host_program_shape: {
parameters {
element_type: S64
layout { tail_padding_alignment_in_elements: 1 }
}
parameters {
element_type: S64
layout { tail_padding_alignment_in_elements: 1 }
}
result {
element_type: TUPLE
tuple_shapes {
element_type: S64
layout { tail_padding_alignment_in_elements: 1 }
}
}
parameter_names: "p0"
parameter_names: "p1"
}
)pb"));

// Define a set of dummy fixed key-value pairs for frontend attributes.
std::vector<std::pair<std::string, std::string>> attr_pairs = {
{"key1", "value1"},
{"key2", "value2"},
{"key3", "value3"},
{"key4", "value4"}};

// Helper: takes the module proto by value (so each call works on its own
// copy), adds the attribute pairs to every instruction in a freshly shuffled
// order, and returns the hash of the deterministic serialization.
// `attr_pairs` is captured by reference, so the shuffle reorders the shared
// vector in place.
auto shuffle_and_hash = [&attr_pairs](xla::HloModuleProto hlo_module_proto) {
// Create a random number generator for shuffling.
// Seeded from std::random_device, so the insertion order differs from run
// to run.
// NOTE(review): nothing here forces the two calls below to use different
// orders; by chance they may coincide — confirm this is acceptable.
std::random_device random_device;
std::mt19937 random_generator(random_device());

for (auto& computation : *hlo_module_proto.mutable_computations()) {
for (auto& instruction : *computation.mutable_instructions()) {
// Reshuffle before every instruction so each one gets its own order.
std::shuffle(attr_pairs.begin(), attr_pairs.end(), random_generator);
auto* frontend_attrs = instruction.mutable_frontend_attributes();
// Add the dummy shuffled pairs to the frontend attributes.
for (const auto& pair : attr_pairs) {
(*frontend_attrs->mutable_map())[pair.first] = pair.second;
}
}
}
std::string serialized_proto =
::ConsumeValue(GetDeterministicSerializedModuleProto(hlo_module_proto));
return torch::lazy::Hash(serialized_proto);
};

// Compute hashes with different random orderings of attributes
torch::lazy::hash_t hash1 = shuffle_and_hash(hlo_module_proto);
torch::lazy::hash_t hash2 = shuffle_and_hash(hlo_module_proto);
// Verify that different orderings produce the same hash
ASSERT_EQ(hash1, hash2)
<< "Hashes should match regardless of the frontend attribute ordering";
}

} // namespace util
} // namespace runtime
} // namespace torch_xla
21 changes: 12 additions & 9 deletions torch_xla/csrc/xla_graph_executor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1200,12 +1200,13 @@ XLAGraphExecutor::LookupCachedCompile(const torch::lazy::hash_t& hash) {
TORCH_LAZY_COUNTER("UncachedCompile", 1);
return nullptr;
}
std::string serialized_computation =
ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
cached_computation->computation->computation().proto()));
TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(hash)
<< " is computation hash "
<< torch::lazy::HashToString(torch::lazy::Hash(
cached_computation->computation->computation()
.proto()
.SerializeAsString()));
<< torch::lazy::HashToString(
torch::lazy::Hash(serialized_computation));
TORCH_LAZY_COUNTER("CachedCompile", 1);
return cached_computation;
}
Expand Down Expand Up @@ -1443,11 +1444,13 @@ XLAGraphExecutor::CompilationResult XLAGraphExecutor::Compile(
<< coll.device << " done!";
TF_VLOG(5) << "Compiled program shape "
<< computations.front()->program_shape().ToString() << std::endl;
TF_VLOG(5)
<< "Graph hash " << torch::lazy::HashToString(coll.hash)
<< " is computation hash "
<< torch::lazy::HashToString(torch::lazy::Hash(
computations.front()->computation().proto().SerializeAsString()));
std::string serialized_computation =
ConsumeValue(runtime::util::GetDeterministicSerializedModuleProto(
computations.front()->computation().proto()));
TF_VLOG(5) << "Graph hash " << torch::lazy::HashToString(coll.hash)
<< " is computation hash "
<< torch::lazy::HashToString(
torch::lazy::Hash(serialized_computation));

if (use_autosharding) {
const xla::HloModuleProto& computation_proto =
Expand Down
Loading