[XLA:GPU] Fix RaggedAllToAllDecomposer when input and output buffers have different sizes.

We were doubling the size of the buffer so that we could use dynamic-update-slice: by HLO semantics, if the update goes out of bounds of the result, the update is not applied at all. Doubling is not sufficient when the input and output buffers have different sizes; the correct solution is to pad to `input_size + output_size`.
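
As a back-of-the-envelope sketch of why doubling falls short when the buffers differ (this snippet is not part of the change; the 32-row input / 16-row output sizes are hypothetical and only mirror the new InputBufferLargerThanOutput test below):

#include <cassert>
#include <cstdint>

int main() {
  // Hypothetical sizes: a 32-row ragged input scattered into a 16-row output.
  const int64_t input_size = 32;   // upper bound on the size of any dense update
  const int64_t output_size = 16;  // size of the ragged output buffer

  // Worst case: a dense update of up to input_size rows inserted near the end
  // of the output, i.e. starting at offset output_size - 1.
  const int64_t worst_case_end = (output_size - 1) + input_size;  // 47

  const int64_t doubled_size = 2 * output_size;           // old padding: 32 rows
  const int64_t padded_size = output_size + input_size;   // new padding: 48 rows

  assert(worst_case_end > doubled_size);   // doubled buffer: update out of bounds
  assert(worst_case_end <= padded_size);   // padded buffer: update always fits
  return 0;
}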

PiperOrigin-RevId: 732239941
olegshyshkov authored and Google-ML-Automation committed Feb 28, 2025
1 parent ee73110 commit 0c04616
Showing 2 changed files with 121 additions and 21 deletions.
39 changes: 22 additions & 17 deletions xla/service/gpu/transforms/ragged_all_to_all_decomposer.cc
@@ -220,22 +220,23 @@ HloInstruction* CreateUpdateMask(HloInstruction* offset_value,
                                  less_than_upper_bound));
 }
 
-// Pads the outermost dimension of the input tensor to double the size.
+// Pads the outermost dimension of the hlo result by the given padding size.
 HloInstruction* PadOutermostDimension(HloComputation* computation,
-                                      HloInstruction* input) {
-  Shape padded_shape = input->shape();
+                                      HloInstruction* hlo,
+                                      int64_t padding_size) {
+  Shape padded_shape = hlo->shape();
 
   PaddingConfig padding_config = MakeNoPaddingConfig(padded_shape.rank());
-  padding_config.mutable_dimensions(0)->set_edge_padding_high(
-      padded_shape.dimensions(0));
+  padding_config.mutable_dimensions(0)->set_edge_padding_high(padding_size);
 
-  padded_shape.set_dimensions(0, 2 * padded_shape.dimensions(0));
+  padded_shape.set_dimensions(0, padded_shape.dimensions(0) + padding_size);
 
   HloInstruction* padding_value =
       computation->AddInstruction(HloInstruction::CreateConstant(
-          LiteralUtil::Zero(input->shape().element_type())));
+          LiteralUtil::Zero(hlo->shape().element_type())));
 
   return computation->AddInstruction(HloInstruction::CreatePad(
-      padded_shape, input, padding_value, padding_config));
+      padded_shape, hlo, padding_value, padding_config));
 }
 
 // Returns dense representation of the ragged input tensor.
@@ -246,7 +247,8 @@ HloInstruction* PadOutermostDimension(HloComputation* computation,
 std::vector<HloInstruction*> RaggedToDense(HloComputation* computation,
                                            HloInstruction* ragged_input,
                                            HloInstruction* offsets,
-                                           int64_t num_updates_per_replica) {
+                                           int64_t num_updates_per_replica,
+                                           int64_t max_update_size) {
   int64_t num_rows = offsets->shape().dimensions(0);
 
   std::vector<HloInstruction*> result;
@@ -259,7 +261,7 @@ std::vector<HloInstruction*> RaggedToDense(HloComputation* computation,
                                        ragged_input->shape().rank());
 
   HloInstruction* padded_input =
-      PadOutermostDimension(computation, ragged_input);
+      PadOutermostDimension(computation, ragged_input, max_update_size);
 
   HloInstruction* row_slice =
       computation->AddInstruction(HloInstruction::CreateDynamicSlice(
@@ -286,17 +288,19 @@ HloInstruction* DenseToRagged(HloComputation* computation,
                               HloInstruction* dense_inputs,
                               HloInstruction* ragged_output,
                               HloInstruction* offsets, HloInstruction* sizes,
-                              int64_t num_updates_per_replica) {
+                              int64_t num_updates_per_replica,
+                              int64_t max_update_size) {
   int64_t num_rows = offsets->shape().dimensions(0);
   int64_t rank = ragged_output->shape().rank();
 
   Shape original_shape = ragged_output->shape();
 
-  // Pad the outermost dimension of the ragged output to double the size. This
-  // is needed to be able to insert updates with dynamic-update-slice to the
-  // ragged output.
+  // Pad the outermost dimension of the ragged output by dense inputs update
+  // size. This is needed to be able to insert updates with dynamic-update-slice
+  // to the ragged output.
   HloInstruction* padded_ragged_output =
-      PadOutermostDimension(computation, ragged_output);
+      PadOutermostDimension(computation, ragged_output,
+                            /*padding_size=*/max_update_size);
 
   for (int64_t i = 0; i < num_rows / num_updates_per_replica; ++i) {
     for (int64_t j = 0; j < num_updates_per_replica; ++j) {
@@ -380,6 +384,7 @@ absl::StatusOr<bool> DecomposeRaggedAllToAll(HloInstruction* hlo,
   int64_t num_total_updates = input_offsets->shape().dimensions(0);
   int64_t num_updates_per_replica =
       num_total_updates / *num_participating_devices;
+  int64_t max_update_size = input_operand->shape().dimensions(0);
 
   // Runs all-to-all to exchange output offsets for each participating device.
   // RaggedAllToAll API requires that output offsets are calculated from the
@@ -391,7 +396,7 @@ absl::StatusOr<bool> DecomposeRaggedAllToAll(HloInstruction* hlo,
                                  *num_participating_devices);
 
   auto dense_input = RaggedToDense(computation, input_operand, input_offsets,
-                                   num_updates_per_replica);
+                                   num_updates_per_replica, max_update_size);
 
   std::vector<Shape> dense_input_shapes;
   dense_input_shapes.reserve(dense_input.size());
@@ -408,7 +413,7 @@ absl::StatusOr<bool> DecomposeRaggedAllToAll(HloInstruction* hlo,
 
   auto* ragged_output =
       DenseToRagged(computation, dense_output, output_operand, output_offsets,
-                    recv_sizes, num_updates_per_replica);
+                    recv_sizes, num_updates_per_replica, max_update_size);
 
   TF_RETURN_IF_ERROR(all_to_all->ReplaceAllUsesWith(ragged_output));
   TF_RETURN_IF_ERROR(
103 changes: 99 additions & 4 deletions xla/tests/collective_ops_e2e_test.cc
@@ -1972,15 +1972,20 @@ class RaggedAllToAllTest : public CollectiveOpsWithFlagsBase,
     EXPECT_THAT(ragged_all_to_all, NotNull());
 
     // Shape of the ragged input tensor.
-    std::vector<int64_t> ragged_tensor_sizes{
+    std::vector<int64_t> input_ragged_tensor_sizes{
         ragged_all_to_all->operand(0)->shape().dimensions().begin(),
         ragged_all_to_all->operand(0)->shape().dimensions().end()};
 
+    // Shape of the ragged output tensor.
+    std::vector<int64_t> output_ragged_tensor_sizes{
+        ragged_all_to_all->shape().dimensions().begin(),
+        ragged_all_to_all->shape().dimensions().end()};
+
     // The ragged-all-to-all accepts an output tensor as a parameter to allow
     // buffer reuse. We initialize the output tensor with -1 to make sure that
     // we don't accidentally overwrite data that is not part of the
     // ragged-all-to-all update.
-    Array<float> output_init_data(ragged_tensor_sizes);
+    Array<float> output_init_data(output_ragged_tensor_sizes);
     output_init_data.Fill(-1);
 
     Array<IndexType> output_sizes = input_sizes;
@@ -1991,8 +1996,8 @@ class RaggedAllToAllTest : public CollectiveOpsWithFlagsBase,
     output_offsets.TransposeDimensions({1, 0, 2});
 
     int64_t num_replicas = input_sizes.dim(0);
-    std::vector<Array<float>> input_data(num_replicas,
-                                         Array<float>(ragged_tensor_sizes));
+    std::vector<Array<float>> input_data(
+        num_replicas, Array<float>(input_ragged_tensor_sizes));
     std::vector<Array<float>> output_data(num_replicas, output_init_data);
     FillWithRandomData(input_data, output_data, input_offsets, output_offsets,
                        input_sizes);
@@ -2168,6 +2173,96 @@ XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs) {
EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1]));
}

XLA_TEST_P(RaggedAllToAllTest,
RaggedAllToAll_2GPUs_InputBufferLargerThanOutput) {
absl::string_view kModuleReplicatedStr = R"(
HloModule module, num_partitions=1
ENTRY entry {
input = f32[32] parameter(0)
output = f32[16] parameter(1)
input_offsets = s32[2] parameter(2)
send_sizes = s32[2] parameter(3)
output_offsets = s32[2] parameter(4)
recv_sizes = s32[2] parameter(5)
ROOT ra2a = f32[16] ragged-all-to-all(input, output, input_offsets,
send_sizes, output_offsets, recv_sizes), replica_groups={{0,1}}
})";

const int64_t kNumReplicas = 2;
const int64_t kNumPartitions = 1;
if (test_runner().device_count() < kNumReplicas * kNumPartitions) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions
<< " devices (" << test_runner().device_count()
<< " available)";
}

HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions);

TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config));

CreateRandomTestData</*IndexType=*/int32_t>(
module.get(), /*input_sizes=*/{/*replica_0=*/{8, 5},
/*replica_1=*/{4, 3}});

TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
HloTestBase::ExecuteReplicated(std::move(module), GetInputLiteralPtrs(),
/*num_replicas=*/kNumReplicas,
/*run_hlo_passes=*/true,
/*device_assignment=*/nullptr));
ASSERT_EQ(results.size(), kNumReplicas);
EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[0], results[0]));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1]));
}

XLA_TEST_P(RaggedAllToAllTest,
RaggedAllToAll_2GPUs_OutputBufferLargerThanInput) {
absl::string_view kModuleReplicatedStr = R"(
HloModule module, num_partitions=1
ENTRY entry {
input = f32[16] parameter(0)
output = f32[32] parameter(1)
input_offsets = s32[2] parameter(2)
send_sizes = s32[2] parameter(3)
output_offsets = s32[2] parameter(4)
recv_sizes = s32[2] parameter(5)
ROOT ra2a = f32[32] ragged-all-to-all(input, output, input_offsets,
send_sizes, output_offsets, recv_sizes), replica_groups={{0,1}}
})";

const int64_t kNumReplicas = 2;
const int64_t kNumPartitions = 1;
if (test_runner().device_count() < kNumReplicas * kNumPartitions) {
GTEST_SKIP() << "Test requires at least " << kNumReplicas * kNumPartitions
<< " devices (" << test_runner().device_count()
<< " available)";
}

HloModuleConfig config =
GetModuleConfigForTest(/*replica_count=*/kNumReplicas * kNumPartitions);

TF_ASSERT_OK_AND_ASSIGN(
auto module, ParseAndReturnVerifiedModule(kModuleReplicatedStr, config));

CreateRandomTestData</*IndexType=*/int32_t>(
module.get(), /*input_sizes=*/{/*replica_0=*/{4, 12},
/*replica_1=*/{5, 11}});

TF_ASSERT_OK_AND_ASSIGN(
std::vector<Literal> results,
HloTestBase::ExecuteReplicated(std::move(module), GetInputLiteralPtrs(),
/*num_replicas=*/kNumReplicas,
/*run_hlo_passes=*/true,
/*device_assignment=*/nullptr));
ASSERT_EQ(results.size(), kNumReplicas);
EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[0], results[0]));
EXPECT_TRUE(LiteralTestUtil::Equal(expected_outputs_[1], results[1]));
}

XLA_TEST_P(RaggedAllToAllTest, RaggedAllToAll_2GPUs_MultipleUpdates) {
absl::string_view kModuleReplicatedStr = R"(
HloModule module, num_partitions=1