Skip to content

Commit

Permalink
[XLA:GPU] Create CUDA-specific API for the runtime to populate the tensor map parameter.
Browse files Browse the repository at this point in history
See child CL for how this is called.

PiperOrigin-RevId: 715377639
  • Loading branch information
vwbaker authored and Google-ML-Automation committed Jan 14, 2025
1 parent f80d088 commit 2940811
Show file tree
Hide file tree
Showing 8 changed files with 312 additions and 0 deletions.
1 change: 1 addition & 0 deletions xla/stream_executor/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,7 @@ cc_library(
":module_spec",
":platform",
":stream",
"//xla/stream_executor/gpu:tma_metadata",
"@com_google_absl//absl/log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
Expand Down
38 changes: 38 additions & 0 deletions xla/stream_executor/cuda/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1012,6 +1012,7 @@ cc_library(
":cuda_stream",
":cuda_timer",
":cuda_version_parser",
":tma_util",
"//xla/stream_executor:activate_context",
"//xla/stream_executor:blas",
"//xla/stream_executor:command_buffer",
Expand All @@ -1036,8 +1037,11 @@ cc_library(
"//xla/stream_executor/gpu:gpu_executor_header",
"//xla/stream_executor/gpu:read_numa_node",
"//xla/stream_executor/gpu:scoped_activate_context",
"//xla/stream_executor/gpu:tma_metadata",
"//xla/tsl/cuda", # buildcleaner: keep
"//xla/tsl/cuda:cudart", # buildcleaner: keep
"//xla/tsl/platform:errors",
"//xla/tsl/platform:statusor",
"@com_google_absl//absl/algorithm:container",
"@com_google_absl//absl/base",
"@com_google_absl//absl/base:core_headers",
Expand Down Expand Up @@ -1864,3 +1868,37 @@ xla_cc_test(
"@tsl//tsl/platform:test",
],
)

# Conversion helpers that translate the platform-neutral TMA metadata types
# (//xla/stream_executor/gpu:tma_metadata) into the CUDA driver's CUtensorMap
# enums. CUDA-only: depends on the CUDA driver headers.
cc_library(
    name = "tma_util",
    srcs = ["tma_util.cc"],
    hdrs = ["tma_util.h"],
    # Restrict to CUDA GPU builds; the target cannot compile without cuda.h.
    tags = [
        "cuda-only",
        "gpu",
    ],
    deps = [
        "//xla/stream_executor/gpu:tma_metadata",
        "@com_google_absl//absl/status",
        "@com_google_absl//absl/status:statusor",
        "@com_google_absl//absl/strings:str_format",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)

# Unit tests for :tma_util. Pure enum-mapping tests — no GPU is required at
# runtime, but the CUDA headers are needed to compile, hence the tags.
cc_test(
    name = "tma_util_test",
    srcs = ["tma_util_test.cc"],
    tags = [
        "cuda-only",
        "gpu",
    ],
    deps = [
        ":tma_util",
        "//xla/stream_executor/gpu:tma_metadata",
        "//xla/tsl/platform:status_matchers",
        "@com_google_absl//absl/status",
        "@com_google_googletest//:gtest_main",
        "@local_config_cuda//cuda:cuda_headers",
    ],
)
36 changes: 36 additions & 0 deletions xla/stream_executor/cuda/cuda_executor.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ limitations under the License.
#include "xla/stream_executor/cuda/cuda_stream.h"
#include "xla/stream_executor/cuda/cuda_timer.h"
#include "xla/stream_executor/cuda/cuda_version_parser.h"
#include "xla/stream_executor/cuda/tma_util.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/dnn.h"
Expand All @@ -67,6 +68,7 @@ limitations under the License.
#include "xla/stream_executor/gpu/context.h"
#include "xla/stream_executor/gpu/read_numa_node.h"
#include "xla/stream_executor/gpu/scoped_activate_context.h"
#include "xla/stream_executor/gpu/tma_metadata.h"
#include "xla/stream_executor/host_memory_allocation.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
Expand All @@ -78,6 +80,8 @@ limitations under the License.
#include "xla/stream_executor/semantic_version.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/tsl/platform/errors.h"
#include "xla/tsl/platform/statusor.h"
#include "tsl/platform/casts.h"
#include "tsl/platform/env.h"
#include "tsl/platform/errors.h"
Expand Down Expand Up @@ -1333,5 +1337,37 @@ absl::StatusOr<const CudaKernel*> CudaExecutor::GetCudaKernel(
}
return static_cast<const CudaKernel*>(*it);
}

// Encodes `tma_desc` into a CUtensorMap via the CUDA driver, copies the
// encoded descriptor into freshly allocated device memory, and returns it so
// it can be passed to a kernel as a regular argument.
//
// `global_address` is the device address of the tensor the map describes.
// Returns InvalidArgument for unsupported element sizes (via
// GetTensorMapDataType) and Internal if the driver rejects the encoding.
absl::StatusOr<DeviceMemoryBase> CudaExecutor::CreateTensorMap(
    TmaDescriptor tma_desc, void* global_address) {
  // Translate the platform-neutral descriptor fields into the CUDA driver's
  // CUtensorMap enums.
  TF_ASSIGN_OR_RETURN(CUtensorMapDataType data_type,
                      GetTensorMapDataType(tma_desc.element_size()));
  CUtensorMapSwizzle swizzle = GetTensorMapSwizzle(tma_desc.swizzle());
  CUtensorMapL2promotion l2_promotion =
      GetTensorMapL2Promotion(tma_desc.l2_promotion());
  CUtensorMapFloatOOBfill float_oob_fill =
      GetTensorMapFloatOOBFill(tma_desc.float_oob_fill());
  CUtensorMapInterleave interleave =
      GetTensorMapInterleave(tma_desc.interleave());

  CUtensorMap tensor_map;
  CUresult result = cuTensorMapEncodeTiled(
      &tensor_map, data_type, tma_desc.rank(), global_address,
      &tma_desc.global_dims()[0], &tma_desc.global_strides()[0],
      &tma_desc.box_dims()[0], &tma_desc.element_strides()[0], interleave,
      swizzle, l2_promotion, float_oob_fill);
  if (result != CUDA_SUCCESS) {
    // cuGetErrorString fails (and sets *pStr to NULL) for error codes it does
    // not recognize; guard so we never format a NULL/uninitialized pointer.
    const char* error_message = nullptr;
    if (cuGetErrorString(result, &error_message) != CUDA_SUCCESS ||
        error_message == nullptr) {
      error_message = "unknown error";
    }
    return absl::InternalError(absl::StrFormat(
        "Failed to create tensormap with cuTensorMapEncodeTiled: %s",
        error_message));
  }
  // Stage the host-side CUtensorMap into device memory; kernels consume the
  // device copy.
  DeviceMemoryBase device_tensor_map = Allocate(sizeof(tensor_map), 0);
  TF_RETURN_IF_ERROR(
      SynchronousMemcpy(&device_tensor_map, &tensor_map, sizeof(tensor_map)));
  return device_tensor_map;
}

} // namespace gpu
} // namespace stream_executor
7 changes: 7 additions & 0 deletions xla/stream_executor/cuda/cuda_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ limitations under the License.
#include "xla/stream_executor/event_based_timer.h"
#include "xla/stream_executor/fft.h"
#include "xla/stream_executor/gpu/gpu_executor.h"
#include "xla/stream_executor/gpu/tma_metadata.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
#include "xla/stream_executor/memory_allocation.h"
Expand Down Expand Up @@ -141,6 +142,12 @@ class CudaExecutor : public GpuExecutor {
// associated with this executor. Otherwise a NotFound error is returned.
absl::StatusOr<const CudaKernel*> GetCudaKernel(const Kernel* kernel);

// Creates, allocates, and copies a CUtensorMap object for the given TMA
// descriptor. Returns a DeviceMemoryBase pointing to the allocated
// CUtensorMap object to be used as an argument to a kernel.
absl::StatusOr<DeviceMemoryBase> CreateTensorMap(
TmaDescriptor tma_desc, void* global_address) override;

private:
// Loads a module in cubin format.
absl::StatusOr<ModuleHandle> LoadModuleFromCuBin(const char* cubin)
Expand Down
91 changes: 91 additions & 0 deletions xla/stream_executor/cuda/tma_util.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/tma_util.h"

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "absl/strings/str_format.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/gpu/tma_metadata.h"

namespace stream_executor::gpu {

// Maps an element size in bytes onto the unsigned-integer CUDA tensor-map
// data type of the same width. Only power-of-two sizes 1/2/4/8 are supported;
// anything else yields InvalidArgument.
absl::StatusOr<CUtensorMapDataType> GetTensorMapDataType(int element_size) {
  if (element_size == 1) return CU_TENSOR_MAP_DATA_TYPE_UINT8;
  if (element_size == 2) return CU_TENSOR_MAP_DATA_TYPE_UINT16;
  if (element_size == 4) return CU_TENSOR_MAP_DATA_TYPE_UINT32;
  if (element_size == 8) return CU_TENSOR_MAP_DATA_TYPE_UINT64;
  return absl::InvalidArgumentError(
      absl::StrFormat("unsupported element size: %d", element_size));
}

// Translates the platform-neutral swizzle mode into the CUDA driver enum.
// TmaSwizzle and CUtensorMapSwizzle correspond one-to-one.
CUtensorMapSwizzle GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle swizzle) {
  if (swizzle == TmaDescriptor::TmaSwizzle::kNone) {
    return CU_TENSOR_MAP_SWIZZLE_NONE;
  }
  if (swizzle == TmaDescriptor::TmaSwizzle::k32B) {
    return CU_TENSOR_MAP_SWIZZLE_32B;
  }
  if (swizzle == TmaDescriptor::TmaSwizzle::k64B) {
    return CU_TENSOR_MAP_SWIZZLE_64B;
  }
  // k128B is the only remaining enumerator.
  return CU_TENSOR_MAP_SWIZZLE_128B;
}

// Translates the platform-neutral L2-promotion granularity into the CUDA
// driver enum. TmaL2Promotion and CUtensorMapL2promotion correspond
// one-to-one.
CUtensorMapL2promotion GetTensorMapL2Promotion(
    TmaDescriptor::TmaL2Promotion l2_promotion) {
  if (l2_promotion == TmaDescriptor::TmaL2Promotion::kNone) {
    return CU_TENSOR_MAP_L2_PROMOTION_NONE;
  }
  if (l2_promotion == TmaDescriptor::TmaL2Promotion::k64B) {
    return CU_TENSOR_MAP_L2_PROMOTION_L2_64B;
  }
  if (l2_promotion == TmaDescriptor::TmaL2Promotion::k128B) {
    return CU_TENSOR_MAP_L2_PROMOTION_L2_128B;
  }
  // k256B is the only remaining enumerator.
  return CU_TENSOR_MAP_L2_PROMOTION_L2_256B;
}

// Translates the platform-neutral out-of-bounds fill mode into the CUDA
// driver enum. TmaFloatOobFill and CUtensorMapFloatOOBfill correspond
// one-to-one.
CUtensorMapFloatOOBfill GetTensorMapFloatOOBFill(
    TmaDescriptor::TmaFloatOobFill oob_fill) {
  if (oob_fill == TmaDescriptor::TmaFloatOobFill::kNone) {
    return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE;
  }
  // kNanRequestZeroFma is the only remaining enumerator.
  return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA;
}

// Translates the platform-neutral interleave layout into the CUDA driver
// enum. TmaInterleave and CUtensorMapInterleave correspond one-to-one.
CUtensorMapInterleave GetTensorMapInterleave(
    TmaDescriptor::TmaInterleave interleave) {
  if (interleave == TmaDescriptor::TmaInterleave::kNone) {
    return CU_TENSOR_MAP_INTERLEAVE_NONE;
  }
  if (interleave == TmaDescriptor::TmaInterleave::k16B) {
    return CU_TENSOR_MAP_INTERLEAVE_16B;
  }
  // k32B is the only remaining enumerator.
  return CU_TENSOR_MAP_INTERLEAVE_32B;
}

} // namespace stream_executor::gpu
40 changes: 40 additions & 0 deletions xla/stream_executor/cuda/tma_util.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef XLA_STREAM_EXECUTOR_CUDA_TMA_UTIL_H_
#define XLA_STREAM_EXECUTOR_CUDA_TMA_UTIL_H_

#include "absl/status/statusor.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/gpu/tma_metadata.h"

namespace stream_executor::gpu {

// Conversion helpers from the platform-neutral TMA metadata types declared in
// tma_metadata.h to the CUDA driver's CUtensorMap enums, as consumed by
// cuTensorMapEncodeTiled.

// Maps an element size in bytes (1, 2, 4, or 8) onto the unsigned-integer
// tensor-map data type of the same width. Returns InvalidArgument for any
// other size.
absl::StatusOr<CUtensorMapDataType> GetTensorMapDataType(int element_size);

// Converts the shared-memory swizzle mode to its CUDA driver equivalent.
CUtensorMapSwizzle GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle swizzle);

// Converts the L2-promotion granularity to its CUDA driver equivalent.
CUtensorMapL2promotion GetTensorMapL2Promotion(
    TmaDescriptor::TmaL2Promotion l2_promotion);

// Converts the float out-of-bounds fill mode to its CUDA driver equivalent.
CUtensorMapFloatOOBfill GetTensorMapFloatOOBFill(
    TmaDescriptor::TmaFloatOobFill oob_fill);

// Converts the interleave layout to its CUDA driver equivalent.
CUtensorMapInterleave GetTensorMapInterleave(
    TmaDescriptor::TmaInterleave interleave);

}  // namespace stream_executor::gpu

#endif  // XLA_STREAM_EXECUTOR_CUDA_TMA_UTIL_H_
89 changes: 89 additions & 0 deletions xla/stream_executor/cuda/tma_util_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
/* Copyright 2025 The OpenXLA Authors.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/stream_executor/cuda/tma_util.h"

#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "absl/status/status.h"
#include "third_party/gpus/cuda/include/cuda.h"
#include "xla/stream_executor/gpu/tma_metadata.h"
#include "xla/tsl/platform/status_matchers.h"

namespace stream_executor::gpu {
namespace {

using ::tsl::testing::IsOkAndHolds;
using ::tsl::testing::StatusIs;

// Every supported element width (1/2/4/8 bytes) maps to the unsigned-integer
// CUDA tensor-map data type of the same width.
TEST(TmaUtilTest, GetTensorMapDataTypeReturnsCorrectDataType) {
  EXPECT_THAT(GetTensorMapDataType(8),
              IsOkAndHolds(CU_TENSOR_MAP_DATA_TYPE_UINT64));
  EXPECT_THAT(GetTensorMapDataType(4),
              IsOkAndHolds(CU_TENSOR_MAP_DATA_TYPE_UINT32));
  EXPECT_THAT(GetTensorMapDataType(2),
              IsOkAndHolds(CU_TENSOR_MAP_DATA_TYPE_UINT16));
  EXPECT_THAT(GetTensorMapDataType(1),
              IsOkAndHolds(CU_TENSOR_MAP_DATA_TYPE_UINT8));
}

// Unsupported element widths (too large, or zero) yield InvalidArgument
// rather than crashing or mapping to an arbitrary type.
TEST(TmaUtilTest, GetTensorMapDataTypeFailsGracefully) {
  EXPECT_THAT(GetTensorMapDataType(16),
              StatusIs(absl::StatusCode::kInvalidArgument));
  EXPECT_THAT(GetTensorMapDataType(0),
              StatusIs(absl::StatusCode::kInvalidArgument));
}

// Each TmaSwizzle enumerator maps to the CUDA swizzle enum of the same name.
TEST(TmaUtilTest, GetTensorMapSwizzleReturnsCorrectSwizzle) {
  EXPECT_EQ(GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle::k128B),
            CU_TENSOR_MAP_SWIZZLE_128B);
  EXPECT_EQ(GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle::k64B),
            CU_TENSOR_MAP_SWIZZLE_64B);
  EXPECT_EQ(GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle::k32B),
            CU_TENSOR_MAP_SWIZZLE_32B);
  EXPECT_EQ(GetTensorMapSwizzle(TmaDescriptor::TmaSwizzle::kNone),
            CU_TENSOR_MAP_SWIZZLE_NONE);
}

// Each TmaL2Promotion enumerator maps to the CUDA L2-promotion enum of the
// same granularity.
TEST(TmaUtilTest, GetTensorMapL2PromotionReturnsCorrectL2Promotion) {
  EXPECT_EQ(GetTensorMapL2Promotion(TmaDescriptor::TmaL2Promotion::k256B),
            CU_TENSOR_MAP_L2_PROMOTION_L2_256B);
  EXPECT_EQ(GetTensorMapL2Promotion(TmaDescriptor::TmaL2Promotion::k128B),
            CU_TENSOR_MAP_L2_PROMOTION_L2_128B);
  EXPECT_EQ(GetTensorMapL2Promotion(TmaDescriptor::TmaL2Promotion::k64B),
            CU_TENSOR_MAP_L2_PROMOTION_L2_64B);
  EXPECT_EQ(GetTensorMapL2Promotion(TmaDescriptor::TmaL2Promotion::kNone),
            CU_TENSOR_MAP_L2_PROMOTION_NONE);
}

// Both TmaFloatOobFill enumerators map to the corresponding CUDA OOB-fill
// enum.
TEST(TmaUtilTest, GetTensorMapFloatOobFillReturnsCorrectFloatOobFill) {
  EXPECT_EQ(GetTensorMapFloatOOBFill(
                TmaDescriptor::TmaFloatOobFill::kNanRequestZeroFma),
            CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA);
  EXPECT_EQ(GetTensorMapFloatOOBFill(TmaDescriptor::TmaFloatOobFill::kNone),
            CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE);
}

// Each TmaInterleave enumerator maps to the CUDA interleave enum of the same
// width.
TEST(TmaUtilTest, GetTensorMapInterleaveReturnsCorrectInterleave) {
  EXPECT_EQ(GetTensorMapInterleave(TmaDescriptor::TmaInterleave::k32B),
            CU_TENSOR_MAP_INTERLEAVE_32B);
  EXPECT_EQ(GetTensorMapInterleave(TmaDescriptor::TmaInterleave::k16B),
            CU_TENSOR_MAP_INTERLEAVE_16B);
  EXPECT_EQ(GetTensorMapInterleave(TmaDescriptor::TmaInterleave::kNone),
            CU_TENSOR_MAP_INTERLEAVE_NONE);
}

} // namespace
} // namespace stream_executor::gpu
10 changes: 10 additions & 0 deletions xla/stream_executor/stream_executor.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ limitations under the License.
#include "xla/stream_executor/event.h"
#include "xla/stream_executor/event_based_timer.h"
#include "xla/stream_executor/fft.h"
#include "xla/stream_executor/gpu/tma_metadata.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/kernel_spec.h"
#include "xla/stream_executor/memory_allocation.h"
Expand Down Expand Up @@ -342,6 +343,15 @@ class StreamExecutor {
// Sets the argument logging mode. Returns true if 'mode' is valid.
// The mode is a bitmask of the kLog* constants.
virtual bool SetArgumentLoggingMode(uint64_t mode) { return false; }

  // Creates, allocates, and copies a CUtensorMap object for the given TMA
  // descriptor. Returns a DeviceMemoryBase pointing to the allocated
  // CUtensorMap object to be used as an argument to a kernel.
  // Only implemented on CUDA GPUs; this base implementation returns an
  // Unimplemented error, so callers must be prepared to handle that status
  // on non-CUDA platforms.
  virtual absl::StatusOr<DeviceMemoryBase> CreateTensorMap(
      gpu::TmaDescriptor tma_desc, void* global_address) {
    return absl::UnimplementedError("Not Implemented");
  }
};

template <typename T>
Expand Down

0 comments on commit 2940811

Please sign in to comment.